!29997 [Auto parallel] [MoE] Support data_parallel + expert_parallel

Merge pull request !29997 from Xiaoda/124-moe-changes
i-robot 2022-02-24 09:23:47 +00:00 committed by Gitee
commit bbcfbce9e0
7 changed files with 384 additions and 106 deletions


@ -10,6 +10,7 @@
- **data_parallel** (int) - The data parallel number. Default: 1.
- **model_parallel** (int) - The model parallel number. Default: 1.
- **expert_parallel** (int) - The expert parallel number. This is effective only when MoE (Mixture of Experts) is applied. Default: 1.
- **pipeline_stage** (int) - The number of stages into which the Transformer is sliced. The value should be positive. Default: 1.
- **micro_batch_num** (int) - The micro size of the batches for pipeline training. Default: 1.
- **optimizer_shard** (bool) - Whether to enable optimizer shard. Default: False.
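
For reference, a minimal usage sketch of the options above (the concrete parallel numbers are assumed examples, not part of this change):

from mindspore.parallel.nn import TransformerOpParallelConfig

# The product data_parallel * model_parallel * expert_parallel must not exceed the device number
# (enforced by the _check_moe_config changes later in this diff).
parallel_config = TransformerOpParallelConfig(data_parallel=2,
                                              model_parallel=2,
                                              expert_parallel=2,
                                              pipeline_stage=1,
                                              micro_batch_num=1,
                                              optimizer_shard=False)
print(parallel_config.expert_parallel)  # 2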


@ -300,6 +300,8 @@ class _Linear(Cell):
e.g. 'ReLU'. Default: None.
expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
used and the first dimension in BatchMatMul indicates expert_num. Default: 1.
outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
Default: 1.
compute_dtype (dtype.Number): The computation type. Default: mstype.float16
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@ -328,6 +330,7 @@ class _Linear(Cell):
activation=_valid_type_checks([type(None), str], "Linear"),
transpose_b=Validator.check_bool,
expert_num=Validator.check_positive_int,
outer_batch=Validator.check_positive_int,
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"Linear"),
compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
@ -341,6 +344,7 @@ class _Linear(Cell):
activation=None,
transpose_b=True,
expert_num=1,
outer_batch=1,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(_Linear, self).__init__()
@ -351,6 +355,7 @@ class _Linear(Cell):
raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
self.expert_num = expert_num
self.outer_batch = outer_batch
if self.expert_num > 1:
self.expert_flag = True
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
@ -377,7 +382,7 @@ class _Linear(Cell):
out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
x = P.Reshape()(x, (-1, self.in_channels))
if self.expert_flag:
x = P.Reshape()(x, (self.expert_num, -1, self.in_channels))
x = P.Reshape()(x, (self.outer_batch, self.expert_num, -1, self.in_channels))
weight = self.cast(self.weight, self.dtype)
x = self.matmul(x, weight)
if self.has_bias:
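
A small shape sketch of the reshape above (NumPy is used only for illustration; the sizes are assumed):

import numpy as np

outer_batch, expert_num = 2, 4          # outer_batch is the data-parallel number passed in by FeedForward
tokens_per_expert, in_channels = 8, 16
x = np.ones((outer_batch * expert_num * tokens_per_expert, in_channels))
# Same reshape as _Linear.construct performs when expert_flag is set:
x = x.reshape(outer_batch, expert_num, -1, in_channels)
print(x.shape)  # (2, 4, 8, 16) -- BatchMatMul then applies each expert's own weight to its slice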


@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from .layers import _check_input_dtype, _check_input_shape
from .op_parallel_config import default_dpmp_config, OpParallelConfig
from .op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
__all__ = ["CrossEntropyLoss"]
@ -33,8 +33,8 @@ class CrossEntropyLoss(Cell):
Calculate the cross entropy loss.
Args:
parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
parallel_config (OpParallelConfig, MoEParallelConfig): The parallel configure. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
- **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
@ -65,13 +65,16 @@ class CrossEntropyLoss(Cell):
def __init__(self, parallel_config=default_dpmp_config):
super(CrossEntropyLoss, self).__init__()
if not isinstance(parallel_config, OpParallelConfig):
raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig, "
"but got the type: {}.".format(type(parallel_config)))
if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config, MoEParallelConfig):
raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
" or MoEParallelConfig, but got the type: {}.".format(type(parallel_config)))
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.sum = P.ReduceSum().shard(((dp, mp),))
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
if isinstance(parallel_config, MoEParallelConfig):
self.onehot = P.OneHot().shard(((dp, mp*parallel_config.expert_parallel), (), ()))
else:
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
# on/off value for onehot, for smooth labeling, modify the off_value
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
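
A construction sketch for the MoE path of this loss (the parallel numbers are assumed and mirror the tests at the end of this change):

from mindspore.parallel.nn import CrossEntropyLoss, TransformerOpParallelConfig

cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
# Passing the MoEParallelConfig sub-config makes OneHot shard as ((dp, mp * expert_parallel), (), ()),
# i.e. ((2, 8), (), ()) here, instead of the plain ((dp, mp), (), ()).
loss = CrossEntropyLoss(parallel_config=cfg.moe_parallel_config)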


@ -19,13 +19,14 @@ import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from .op_parallel_config import default_dpmp_config
from .op_parallel_config import default_moeparallel_config
__all__ = [
"MoEConfig"]
@ -75,14 +76,29 @@ default_moe_config = MoEConfig()
def _check_moe_config(moe_config=None, parallel_config=None):
"""
Check whether MoE is configured correctly.
"""
if not isinstance(moe_config, MoEConfig):
raise TypeError(f"'moe_config' should be an instance of MoEConfig, but got {type(moe_config).__name__}.")
use_moe = (moe_config.expert_num > 1)
if use_moe and moe_config.expert_num % parallel_config.data_parallel != 0:
if use_moe is False:
return
if moe_config.expert_num % parallel_config.expert_parallel != 0:
raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
f"of 'data_parallel' value in {type(parallel_config).__name__}, but got "
f"{moe_config.expert_num} for 'expert_num' and {parallel_config.data_parallel} for "
f"'data_parallel'.")
f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for "
f"'expert_parallel'.")
device_num = D.get_group_size()
if device_num % parallel_config.expert_parallel != 0:
raise ValueError(f"device_num: {device_num} should be a multiple of expert_parallel: "
f"{parallel_config.expert_parallel}.")
if parallel_config.data_parallel * parallel_config.model_parallel * parallel_config.expert_parallel > device_num:
raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel}, "
f"model parallel: {parallel_config.model_parallel}, and "
f"expert parallel: {parallel_config.expert_parallel} "
f"should be less than device_num: {device_num}.")
@constexpr
@ -106,9 +122,8 @@ class MoE(Cell):
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with default
args.
parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.
Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
@ -122,7 +137,7 @@ class MoE(Cell):
hidden_act='gelu',
param_init_type=mstype.float32,
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
parallel_config=default_moeparallel_config):
super(MoE, self).__init__()
self.hidden_size = hidden_size
self.expert_dim = moe_config.expert_num
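
A minimal sketch of the arithmetic enforced by _check_moe_config above (the real check reads device_num from the communication group; the numbers below are assumed):

def moe_layout_is_valid(expert_num, data_parallel, model_parallel, expert_parallel, device_num):
    """Illustration only: mirrors the divisibility and capacity checks in _check_moe_config."""
    if expert_num % expert_parallel != 0:
        return False                      # expert_num must be a multiple of expert_parallel
    if device_num % expert_parallel != 0:
        return False                      # device_num must be a multiple of expert_parallel
    return data_parallel * model_parallel * expert_parallel <= device_num

print(moe_layout_is_valid(4, data_parallel=2, model_parallel=4, expert_parallel=2, device_num=16))  # True
print(moe_layout_is_valid(4, data_parallel=1, model_parallel=8, expert_parallel=4, device_num=16))  # False: 1*8*4 > 16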


@ -39,6 +39,57 @@ class _Config:
return info
class MoEParallelConfig(_Config):
r"""
MoEParallelConfig for the MoE structure, which includes setting the data parallel, model parallel and expert parallel numbers.
Args:
data_parallel (int): The data parallel way. Default: 1.
model_parallel (int): The model parallel way. Default: 1.
expert_parallel (int): The expert parallel way. Default: 1.
Supported Platforms:
``Ascend`` ``GPU``
"""
def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1):
Validator.check_positive_int(data_parallel, "data_parallel")
Validator.check_positive_int(model_parallel, "model_parallel")
Validator.check_positive_int(expert_parallel, "expert_parallel")
self._dpmp = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
self.expert_parallel = expert_parallel
@property
def data_parallel(self):
return self._dpmp.data_parallel
@data_parallel.setter
def data_parallel(self, value):
Validator.check_positive_int(value, "data_parallel")
self._dpmp.data_parallel = value
@property
def model_parallel(self):
return self._dpmp.model_parallel
@model_parallel.setter
def model_parallel(self, value):
Validator.check_positive_int(value, "model_parallel")
self._dpmp.model_parallel = value
@property
def expert_parallel(self):
return self._expert_parallel
@expert_parallel.setter
def expert_parallel(self, value):
Validator.check_positive_int(value, "expert_parallel")
self._expert_parallel = value
@property
def dpmp(self):
return self._dpmp
class OpParallelConfig(_Config):
r"""
OpParallelConfig for setting the data parallel and model parallel.
@ -121,6 +172,7 @@ class _PipeLineConfig(_Config):
# In case the user doesn't pass a config as args.
default_dpmp_config = OpParallelConfig()
default_moeparallel_config = MoEParallelConfig()
def _check_config(config):
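
A short usage sketch showing how MoEParallelConfig is reached and consumed (values assumed):

from mindspore.parallel.nn import TransformerOpParallelConfig

cfg = TransformerOpParallelConfig(data_parallel=4, model_parallel=2, expert_parallel=2)
moe_cfg = cfg.moe_parallel_config       # a MoEParallelConfig instance (property added in this change)
print(moe_cfg.data_parallel, moe_cfg.model_parallel, moe_cfg.expert_parallel)  # 4 2 2
# The dpmp view exposes only dp/mp; the Transformer layers below pass it to MultiHeadAttention
# when MoE is enabled, since attention knows nothing about expert parallelism.
print(moe_cfg.dpmp.data_parallel, moe_cfg.dpmp.model_parallel)                 # 4 2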


@ -34,7 +34,8 @@ from mindspore.context import ParallelMode
from .layers import _LayerNorm, _Linear, _Dropout, _check_input_shape, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config, \
MoEParallelConfig
from .moe import default_moe_config, MoE, _check_moe_config
__all__ = [
@ -212,6 +213,8 @@ class TransformerOpParallelConfig(_Config):
according to the data parallel way. Default: 1.
model_parallel (int): The model parallel way. The parameters of dense layers in MultiheadAttention and
FeedForward layer will be sliced according to the model parallel way. Default: 1.
expert_parallel (int): The expert parallel way. This is effective only when MoE (Mixture of Experts) is applied.
This value specifies the number of partitions to split the experts into. Default: 1.
pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
micro_batch_num (int): The micro size of the batches for the pipeline training. Default: 1.
optimizer_shard (bool): Whether to enable optimizer shard. Default False.
@ -230,7 +233,7 @@ class TransformerOpParallelConfig(_Config):
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
"""
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1,
recompute=default_transformer_recompute_config,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
self.recompute = recompute
@ -239,6 +242,8 @@ class TransformerOpParallelConfig(_Config):
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
vocab_emb_dp=vocab_emb_dp)
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
expert_parallel=expert_parallel)
@property
def recompute(self):
@ -284,6 +289,7 @@ class TransformerOpParallelConfig(_Config):
@model_parallel.setter
def model_parallel(self, value):
self._embed_dp_mp_config.model_parallel = value
self._moe_config.model_parallel = value
@property
def data_parallel(self):
@ -292,6 +298,15 @@ class TransformerOpParallelConfig(_Config):
@data_parallel.setter
def data_parallel(self, value):
self._embed_dp_mp_config.data_parallel = value
self._moe_config.data_parallel = value
@property
def expert_parallel(self):
return self._moe_config.expert_parallel
@expert_parallel.setter
def expert_parallel(self, value):
self._moe_config.expert_parallel = value
@property
def pipeline_stage(self):
@ -342,6 +357,10 @@ class TransformerOpParallelConfig(_Config):
"""
return self._embed_dp_mp_config.dp_mp_config
@property
def moe_parallel_config(self):
return self._moe_config
default_transformer_config = TransformerOpParallelConfig()
default_embedding_parallel_config = EmbeddingOpParallelConfig()
@ -370,9 +389,9 @@ class FeedForward(Cell):
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
param_init_type: (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
Default: dtype.float32.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with
default args.
parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see `OpParallelConfig` or
`MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective, otherwise OpParallelConfig is
effective. Default `default_dpmp_config`, an instance of `OpParallelConfig` with default args.
Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
@ -409,7 +428,7 @@ class FeedForward(Cell):
hidden_act=_valid_type_checks([str], "FeedForward"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"FeedForward"))
def __init__(self, hidden_size,
ffn_hidden_size,
@ -422,6 +441,10 @@ class FeedForward(Cell):
_check_config(parallel_config)
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the num of "
"model parallel, but got the ffn_hidden_size is {} and the num of model parallel is {}."
@ -435,20 +458,20 @@ class FeedForward(Cell):
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
ep = dp
# Project to ffn_hidden_size
self.mapping = _Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)
if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, 1, mp),))
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
strategy_bias=((dp, ep, 1, mp), (mp,)),
strategy_activation=((dp, ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
@ -458,10 +481,11 @@ class FeedForward(Cell):
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)
if expert_num > 1:
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)),
strategy_bias=((ep, 1, 1), (1,)))
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
strategy_bias=((dp, ep, 1, 1), (1,)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
@ -470,6 +494,8 @@ class FeedForward(Cell):
self.dropout.shard(((dp, 1),))
self.dropout_3d = _Dropout(1 - dropout_rate)
self.dropout_3d.shard(((dp, 1, 1),))
self.dropout_4d = _Dropout(1 - dropout_rate)
self.dropout_4d.shard(((dp, ep, 1, 1),))
self.cast = P.Cast()
def construct(self, x):
@ -482,8 +508,10 @@ class FeedForward(Cell):
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
else:
elif len(F.shape(output)) == 2:
output = self.dropout(output)
else:
output = self.dropout_4d(output)
return output
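
A construction sketch for the MoE branch of FeedForward (the device count, sizes, auto-parallel context and the FeedForward import path are assumptions modelled on the tests at the end of this change):

import mindspore.common.dtype as mstype
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.parallel.nn import FeedForward, TransformerOpParallelConfig

set_auto_parallel_context(device_num=8, global_rank=0, full_batch=True,
                          parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=2, expert_parallel=2)
# With expert_num > 1, ep is read from the MoE sub-config and the activations inside the
# _Linear layers become 4-D: [outer_batch=dp, expert_num, tokens, channels].
ffn = FeedForward(hidden_size=64,
                  ffn_hidden_size=256,
                  dropout_rate=0.1,
                  hidden_act='gelu',
                  expert_num=4,
                  param_init_type=mstype.float32,
                  parallel_config=cfg.moe_parallel_config)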
@ -1164,7 +1192,8 @@ class TransformerEncoderLayer(Cell):
pass the single step's input tensor, and loop it. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied,
MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
@ -1253,7 +1282,7 @@ class TransformerEncoderLayer(Cell):
"TransformerEncoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerEncoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerEncoderLayer"),
use_past=Validator.check_bool)
def __init__(self,
@ -1287,6 +1316,8 @@ class TransformerEncoderLayer(Cell):
"by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config. model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
self.use_past = use_past
self.seq_length = seq_length
self.hidden_size = hidden_size
@ -1295,20 +1326,30 @@ class TransformerEncoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe is True:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
@ -1480,7 +1521,8 @@ class TransformerDecoderLayer(Cell):
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied,
MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
@ -1553,7 +1595,7 @@ class TransformerDecoderLayer(Cell):
"TransformerDecoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerDecoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerDecoderLayer"),
use_past=Validator.check_bool)
def __init__(self, hidden_size,
@ -1588,6 +1630,8 @@ class TransformerDecoderLayer(Cell):
"divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config.model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if use_past:
raise ValueError(f"The {self.cls_name} does not support use_past=True.")
self.batch_size = batch_size
@ -1603,35 +1647,59 @@ class TransformerDecoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
if self.use_moe is True:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
# Cross attention with the output of encoder as memory tensor
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
if self.use_moe is True:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
layernorm_compute_type)
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
@ -2019,21 +2087,38 @@ class TransformerEncoder(Cell):
self.num_layers = num_layers
self.blocks = nn.CellList()
for i in range(num_layers):
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
if self.use_moe is True:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()
@ -2214,22 +2299,40 @@ class TransformerDecoder(Cell):
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
for i in range(num_layers):
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
if self.use_moe is True:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()


@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.ops import functional as F
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig, CrossEntropyLoss
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
@ -64,6 +67,11 @@ class NetWithLossFiveInputs(nn.Cell):
def test_transformer_model():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
@ -95,6 +103,11 @@ def test_transformer_model():
def test_transformer_model_2d():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
@ -124,3 +137,89 @@ def test_transformer_model_2d():
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
class TransformerNet(nn.Cell):
"""Transformer with loss"""
def __init__(self, en_layer, de_layer, parallel_config):
super(TransformerNet, self).__init__()
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
moe_config=moe_config,
parallel_config=parallel_config)
self.loss = CrossEntropyLoss(parallel_config=config.moe_parallel_config)
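# Note: `moe_config` and `config` above are module-level fixtures defined earlier in this test
# file; they are unchanged by this diff and therefore not shown here.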
def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)
def moe_with_loss_plus_mutiparallel(local_parallel_config):
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
label = Tensor(np.ones((20,)), mstype.int32)
input_mask = Tensor(np.ones((20,)), mstype.float32)
net = TransformerNet(en_layer=1, de_layer=1, parallel_config=local_parallel_config)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask, label, input_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
def test_moe_expert_parallel1():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel2():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel3():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel_exception():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: data_parallel*model_parallel*expert_parallel > device_num
Expectation: Raise ValueError.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=4)
with pytest.raises(ValueError):
moe_with_loss_plus_mutiparallel(local_p_config)