forked from mindspore-Ecosystem/mindspore
!29997 [Auto parallel] [MoE] Support data_parallel + expert_parallel
Merge pull request !29997 from Xiaoda/124-moe-changes
This commit is contained in:
commit bbcfbce9e0
@@ -10,6 +10,7 @@
- **data_parallel** (int) - The data parallel number. Default: 1.
- **model_parallel** (int) - The model parallel number. Default: 1.
- **expert_parallel** (int) - The expert parallel number, which takes effect only when the mixture-of-experts (MoE) structure is applied. Default: 1.
- **pipeline_stage** (int) - The number of stages into which the Transformer is partitioned. The value should be a positive integer. Default: 1.
- **micro_batch_num** (int) - The micro size of the batches for pipeline training. Default: 1.
- **optimizer_shard** (bool) - Whether to enable optimizer sharding. Default: False.
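For orientation, a minimal usage sketch of the extended config; the concrete numbers are illustrative and match the tests added at the end of this change:

```python
from mindspore.parallel.nn import TransformerOpParallelConfig

# 2 (data) x 4 (model) x 2 (expert) = 16 fits a 16-device job,
# which is what the _check_moe_config capacity check later requires.
parallel_config = TransformerOpParallelConfig(data_parallel=2,
                                              model_parallel=4,
                                              expert_parallel=2)
print(parallel_config.expert_parallel)      # 2
print(parallel_config.moe_parallel_config)  # the underlying MoEParallelConfig
```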
@@ -300,6 +300,8 @@ class _Linear(Cell):
e.g. 'ReLU'. Default: None.
expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
used and the first dimension in BatchMatMul indicates expert_num. Default: 1.
outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
Default: 1.
compute_dtype (dtype.Number): The computation type. Default: mstype.float16.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@@ -328,6 +330,7 @@ class _Linear(Cell):
activation=_valid_type_checks([type(None), str], "Linear"),
transpose_b=Validator.check_bool,
expert_num=Validator.check_positive_int,
outer_batch=Validator.check_positive_int,
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"Linear"),
compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
@@ -341,6 +344,7 @@ class _Linear(Cell):
activation=None,
transpose_b=True,
expert_num=1,
outer_batch=1,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(_Linear, self).__init__()
@@ -351,6 +355,7 @@ class _Linear(Cell):
raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
self.expert_num = expert_num
self.outer_batch = outer_batch
if self.expert_num > 1:
self.expert_flag = True
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
@@ -377,7 +382,7 @@ class _Linear(Cell):
out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
x = P.Reshape()(x, (-1, self.in_channels))
if self.expert_flag:
x = P.Reshape()(x, (self.expert_num, -1, self.in_channels))
x = P.Reshape()(x, (self.outer_batch, self.expert_num, -1, self.in_channels))
weight = self.cast(self.weight, self.dtype)
x = self.matmul(x, weight)
if self.has_bias:
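The change above keeps an extra replication axis (`outer_batch`) in front of the expert axis, so the expert matmul now runs on a 4-D layout. A shapes-only NumPy sketch of the same data movement (not MindSpore code):

```python
import numpy as np

outer_batch, expert_num = 2, 4                 # replication axis and expert axis
tokens_per_expert, in_ch, out_ch = 8, 16, 32

x = np.random.randn(outer_batch * expert_num * tokens_per_expert, in_ch)
w = np.random.randn(expert_num, in_ch, out_ch)  # one weight matrix per expert

# old layout: (expert_num, -1, in_channels); new layout adds the outer_batch axis
x4d = x.reshape(outer_batch, expert_num, -1, in_ch)

# per-expert matmul, broadcasting the expert weights over the outer_batch axis
y = np.einsum('beti,eio->beto', x4d, w)
print(y.shape)   # (2, 4, 8, 32)
```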
@@ -23,7 +23,7 @@ from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from .layers import _check_input_dtype, _check_input_shape
from .op_parallel_config import default_dpmp_config, OpParallelConfig
from .op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig

__all__ = ["CrossEntropyLoss"]
@@ -33,8 +33,8 @@ class CrossEntropyLoss(Cell):
Calculate the cross entropy loss.

Args:
parallel_config (OpParallelConfig): The parallel configuration. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
parallel_config (OpParallelConfig, MoEParallelConfig): The parallel configuration. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.

Inputs:
- **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
@@ -65,13 +65,16 @@ class CrossEntropyLoss(Cell):

def __init__(self, parallel_config=default_dpmp_config):
super(CrossEntropyLoss, self).__init__()
if not isinstance(parallel_config, OpParallelConfig):
raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig, "
"but got the type: {}.".format(type(parallel_config)))
if not isinstance(parallel_config, OpParallelConfig) and not isinstance(parallel_config, MoEParallelConfig):
raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
" or MoEParallelConfig, but got the type: {}.".format(type(parallel_config)))
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.sum = P.ReduceSum().shard(((dp, mp),))
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
if isinstance(parallel_config, MoEParallelConfig):
self.onehot = P.OneHot().shard(((dp, mp*parallel_config.expert_parallel), (), ()))
else:
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
# on/off value for onehot, for smooth labeling, modify the off_value
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
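With a `MoEParallelConfig`, the OneHot expansion is sharded over `mp * expert_parallel` devices along the class axis. A minimal construction sketch, following the test added at the end of this change (the parallel numbers are illustrative):

```python
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss

cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
# Passing the MoE sub-config makes CrossEntropyLoss pick the (dp, mp * expert_parallel) OneHot sharding.
loss = CrossEntropyLoss(parallel_config=cfg.moe_parallel_config)
```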
@@ -19,13 +19,14 @@ import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from .op_parallel_config import default_dpmp_config
from .op_parallel_config import default_moeparallel_config

__all__ = [
"MoEConfig"]
@@ -75,14 +76,29 @@ default_moe_config = MoEConfig()


def _check_moe_config(moe_config=None, parallel_config=None):
"""
Check whether MoE is configured correctly.
"""
if not isinstance(moe_config, MoEConfig):
raise TypeError(f"'moe_config' should be an instance of MoEConfig, but got {type(moe_config).__name__}.")
use_moe = (moe_config.expert_num > 1)
if use_moe and moe_config.expert_num % parallel_config.data_parallel != 0:
if use_moe is False:
return
if moe_config.expert_num % parallel_config.expert_parallel != 0:
raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
f"of 'data_parallel' value in {type(parallel_config).__name__}, but got "
f"{moe_config.expert_num} for 'expert_num' and {parallel_config.data_parallel} for "
f"'data_parallel'.")
f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for "
f"'expert_parallel'.")

device_num = D.get_group_size()
if device_num % parallel_config.expert_parallel != 0:
raise ValueError(f"device_num: {device_num} should be a multiple of expert_parallel: "
f"{parallel_config.expert_parallel}.")
if parallel_config.data_parallel * parallel_config.model_parallel * parallel_config.expert_parallel > device_num:
raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel}, "
f"model parallel: {parallel_config.model_parallel}, and "
f"expert parallel: {parallel_config.expert_parallel} "
f"should be less than or equal to device_num: {device_num}.")


@constexpr
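The new checks boil down to three conditions. A standalone plain-Python sketch of the same logic (`device_num` stands in for `D.get_group_size()`; the function name is hypothetical):

```python
def check_moe_layout(expert_num, data_parallel, model_parallel, expert_parallel, device_num):
    """Mirror of the constraints enforced by _check_moe_config in this change."""
    if expert_num <= 1:        # MoE disabled, nothing to check
        return
    if expert_num % expert_parallel != 0:
        raise ValueError("expert_num must be a multiple of expert_parallel")
    if device_num % expert_parallel != 0:
        raise ValueError("device_num must be a multiple of expert_parallel")
    if data_parallel * model_parallel * expert_parallel > device_num:
        raise ValueError("data_parallel * model_parallel * expert_parallel must not exceed device_num")

check_moe_layout(expert_num=4, data_parallel=2, model_parallel=4, expert_parallel=2, device_num=16)  # passes
```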
@@ -106,9 +122,8 @@ class MoE(Cell):
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with default
args.
parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.

Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
@@ -122,7 +137,7 @@ class MoE(Cell):
hidden_act='gelu',
param_init_type=mstype.float32,
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
parallel_config=default_moeparallel_config):
super(MoE, self).__init__()
self.hidden_size = hidden_size
self.expert_dim = moe_config.expert_num
@@ -39,6 +39,57 @@ class _Config:
return info


class MoEParallelConfig(_Config):
r"""
MoEParallelConfig for MoE structure, which includes setting data parallel, model parallel and expert parallel.

Args:
data_parallel (int): The data parallel way. Default: 1
model_parallel (int): The model parallel way. Default: 1
expert_parallel (int): The expert parallel way. Default: 1
Supported Platforms:
``Ascend`` ``GPU``
"""

def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1):
Validator.check_positive_int(data_parallel, "data_parallel")
Validator.check_positive_int(model_parallel, "model_parallel")
Validator.check_positive_int(expert_parallel, "expert_parallel")
self._dpmp = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
self.expert_parallel = expert_parallel

@property
def data_parallel(self):
return self._dpmp.data_parallel

@data_parallel.setter
def data_parallel(self, value):
Validator.check_positive_int(value, "data_parallel")
self._dpmp.data_parallel = value

@property
def model_parallel(self):
return self._dpmp.model_parallel

@model_parallel.setter
def model_parallel(self, value):
Validator.check_positive_int(value, "model_parallel")
self._dpmp.model_parallel = value

@property
def expert_parallel(self):
return self._expert_parallel

@expert_parallel.setter
def expert_parallel(self, value):
Validator.check_positive_int(value, "expert_parallel")
self._expert_parallel = value

@property
def dpmp(self):
return self._dpmp


class OpParallelConfig(_Config):
r"""
OpParallelConfig for the setting data parallel and model parallel.
@@ -121,6 +172,7 @@ class _PipeLineConfig(_Config):

# In case the user doesn't pass a config as args.
default_dpmp_config = OpParallelConfig()
default_moeparallel_config = MoEParallelConfig()


def _check_config(config):
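A quick sketch of how the new config object behaves, based on the class definition above; the import path is an assumption (the class is defined in `op_parallel_config.py` and may not be re-exported), and the numbers are arbitrary:

```python
# Assumption: adjust the import to wherever MoEParallelConfig is exported in your MindSpore version.
from mindspore.parallel.nn import MoEParallelConfig

moe_cfg = MoEParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
print(moe_cfg.data_parallel, moe_cfg.model_parallel, moe_cfg.expert_parallel)  # 2 4 2
print(moe_cfg.dpmp)  # the embedded OpParallelConfig handed to the attention blocks
```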
@@ -34,7 +34,8 @@ from mindspore.context import ParallelMode
from .layers import _LayerNorm, _Linear, _Dropout, _check_input_shape, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config, \
MoEParallelConfig
from .moe import default_moe_config, MoE, _check_moe_config

__all__ = [
@@ -212,6 +213,8 @@ class TransformerOpParallelConfig(_Config):
according to the data parallel way. Default: 1.
model_parallel (int): The model parallel way. The parameters of dense layers in MultiheadAttention and
FeedForward layer will be sliced according to the model parallel way. Default: 1.
expert_parallel (int): The expert parallel way. This is effective only when MoE (Mixture of Experts) is applied.
This value specifies the number of partitions to split the experts into. Default: 1.
pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
micro_batch_num (int): The micro size of the batches for the pipeline training. Default: 1.
optimizer_shard (bool): Whether to enable optimizer shard. Default: False.
@@ -230,7 +233,7 @@ class TransformerOpParallelConfig(_Config):
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
"""

def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1,
recompute=default_transformer_recompute_config,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
self.recompute = recompute
@@ -239,6 +242,8 @@ class TransformerOpParallelConfig(_Config):
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
vocab_emb_dp=vocab_emb_dp)
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
expert_parallel=expert_parallel)

@property
def recompute(self):
@@ -284,6 +289,7 @@ class TransformerOpParallelConfig(_Config):
@model_parallel.setter
def model_parallel(self, value):
self._embed_dp_mp_config.model_parallel = value
self._moe_config.model_parallel = value

@property
def data_parallel(self):
@@ -292,6 +298,15 @@ class TransformerOpParallelConfig(_Config):
@data_parallel.setter
def data_parallel(self, value):
self._embed_dp_mp_config.data_parallel = value
self._moe_config.data_parallel = value

@property
def expert_parallel(self):
return self._moe_config.expert_parallel

@expert_parallel.setter
def expert_parallel(self, value):
self._moe_config.expert_parallel = value

@property
def pipeline_stage(self):
@@ -342,6 +357,10 @@ class TransformerOpParallelConfig(_Config):
"""
return self._embed_dp_mp_config.dp_mp_config

@property
def moe_parallel_config(self):
return self._moe_config


default_transformer_config = TransformerOpParallelConfig()
default_embedding_parallel_config = EmbeddingOpParallelConfig()
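Because the setters above forward to the embedded configs, changing the top-level values keeps the MoE sub-config in sync. A small sketch (values arbitrary):

```python
from mindspore.parallel.nn import TransformerOpParallelConfig

cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
cfg.data_parallel = 4          # propagated to both the embedding and the MoE sub-configs
assert cfg.moe_parallel_config.data_parallel == 4
assert cfg.moe_parallel_config.expert_parallel == 2
```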
@@ -370,9 +389,9 @@ class FeedForward(Cell):
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
Default: dtype.float32.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with
default args.
parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see `OpParallelConfig` or
`MoEParallelConfig`. When MoE is applied, MoEParallelConfig is effective, otherwise OpParallelConfig is
effective. Default `default_dpmp_config`, an instance of `OpParallelConfig` with default args.

Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
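A hedged construction sketch of an expert-parallel FeedForward under this change; the keyword names mirror the code in this diff, while the `FeedForward` export path and the 16-device semi-auto-parallel setup copied from the tests below are assumptions:

```python
import mindspore.common.dtype as mstype
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.parallel.nn import FeedForward, TransformerOpParallelConfig  # FeedForward export path assumed

# Same simulated 16-device, semi-auto-parallel setup as the tests at the end of this change.
set_auto_parallel_context(device_num=16, global_rank=0, full_batch=True, enable_alltoall=True,
                          parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
ffn = FeedForward(hidden_size=64,
                  ffn_hidden_size=256,
                  dropout_rate=0.1,
                  expert_num=4,                      # > 1 selects the expert (4-D) sharding branch
                  param_init_type=mstype.float32,
                  parallel_config=cfg.moe_parallel_config)
```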
@@ -409,7 +428,7 @@ class FeedForward(Cell):
hidden_act=_valid_type_checks([str], "FeedForward"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"FeedForward"))
def __init__(self, hidden_size,
ffn_hidden_size,
@@ -422,6 +441,10 @@ class FeedForward(Cell):
_check_config(parallel_config)
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the num of "
"model parallel, but got the ffn_hidden_size is {} and the num of model parallel is {}."
@@ -435,20 +458,20 @@ class FeedForward(Cell):
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
ep = dp

# Project to ffn_hidden_size
self.mapping = _Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)

if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, 1, mp),))
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
strategy_bias=((dp, ep, 1, mp), (mp,)),
strategy_activation=((dp, ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
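In a shard strategy each entry gives the number of slices of the corresponding tensor dimension, so `((dp, ep, 1, 1), (ep, 1, mp))` splits the 4-D activation across data- and expert-parallel groups and the per-expert weight across expert- and model-parallel groups. A plain-Python sketch of how a strategy maps a global shape to a per-device shape (illustrative only, not MindSpore code):

```python
def local_shape(global_shape, strategy):
    """Split each dimension by the corresponding strategy entry (must divide evenly)."""
    assert len(global_shape) == len(strategy)
    return tuple(dim // cuts for dim, cuts in zip(global_shape, strategy))

dp, ep, mp = 2, 2, 4
# activation (outer_batch, expert_num, tokens, hidden) and per-expert weight (expert_num, hidden, ffn_hidden)
print(local_shape((2, 4, 128, 64), (dp, ep, 1, 1)))   # (1, 2, 128, 64)
print(local_shape((4, 64, 256), (ep, 1, mp)))         # (2, 64, 64)
```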
@@ -458,10 +481,11 @@ class FeedForward(Cell):
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)
if expert_num > 1:
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)),
strategy_bias=((ep, 1, 1), (1,)))
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
strategy_bias=((dp, ep, 1, 1), (1,)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
@@ -470,6 +494,8 @@ class FeedForward(Cell):
self.dropout.shard(((dp, 1),))
self.dropout_3d = _Dropout(1 - dropout_rate)
self.dropout_3d.shard(((dp, 1, 1),))
self.dropout_4d = _Dropout(1 - dropout_rate)
self.dropout_4d.shard(((dp, ep, 1, 1),))
self.cast = P.Cast()

def construct(self, x):
@@ -482,8 +508,10 @@ class FeedForward(Cell):
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
else:
elif len(F.shape(output)) == 2:
output = self.dropout(output)
else:
output = self.dropout_4d(output)
return output
@@ -1164,7 +1192,8 @@ class TransformerEncoderLayer(Cell):
pass the single step's input tensor, and loop it. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configuration. Default `default_dpmp_config`,
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.

Inputs:
@@ -1253,7 +1282,7 @@ class TransformerEncoderLayer(Cell):
"TransformerEncoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerEncoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerEncoderLayer"),
use_past=Validator.check_bool)
def __init__(self,
@@ -1287,6 +1316,8 @@ class TransformerEncoderLayer(Cell):
"by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config. model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
self.use_past = use_past
self.seq_length = seq_length
self.hidden_size = hidden_size
@@ -1295,20 +1326,30 @@ class TransformerEncoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))

self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe is True:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
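Tying the branch above together: with `expert_num > 1` the layer expects a `MoEParallelConfig`, passes its `dpmp` view to attention and the full MoE parallel config to the feed-forward. A hedged construction sketch (keyword names mirror the calls in this diff; the `TransformerEncoderLayer` export path and the 16-device semi-auto-parallel setup copied from the tests below are assumptions):

```python
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.parallel.nn import TransformerEncoderLayer, TransformerOpParallelConfig, MoEConfig
# TransformerEncoderLayer being importable from mindspore.parallel.nn is an assumption.

# Mirror the test setup at the end of this change: 16 simulated devices, semi-auto parallel, All2All on.
set_auto_parallel_context(device_num=16, global_rank=0, full_batch=True, enable_alltoall=True,
                          parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
layer = TransformerEncoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64,
                                num_heads=8, seq_length=16,
                                moe_config=MoEConfig(expert_num=4),          # expert_num > 1 enables the MoE branch
                                parallel_config=cfg.moe_parallel_config)
```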
@@ -1480,7 +1521,8 @@ class TransformerDecoderLayer(Cell):
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configuration. Default `default_dpmp_config`,
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.

Inputs:
@@ -1553,7 +1595,7 @@ class TransformerDecoderLayer(Cell):
"TransformerDecoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerDecoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerDecoderLayer"),
use_past=Validator.check_bool)
def __init__(self, hidden_size,
@@ -1588,6 +1630,8 @@ class TransformerDecoderLayer(Cell):
"divisible by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config.model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if use_past:
raise ValueError(f"The {self.cls_name} does not support use_past=True.")
self.batch_size = batch_size
@@ -1603,35 +1647,59 @@ class TransformerDecoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))

self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
if self.use_moe is True:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
# Cross attention with the output of encoder as memory tensor
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
if self.use_moe is True:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
layernorm_compute_type)
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)

if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
@@ -2019,21 +2087,38 @@ class TransformerEncoder(Cell):
self.num_layers = num_layers
self.blocks = nn.CellList()
for i in range(num_layers):
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
if self.use_moe is True:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()
@@ -2214,22 +2299,40 @@ class TransformerDecoder(Cell):
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
for i in range(num_layers):
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
if self.use_moe is True:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as P
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.ops import functional as F
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig, CrossEntropyLoss
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
@@ -64,6 +67,11 @@ class NetWithLossFiveInputs(nn.Cell):


def test_transformer_model():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
@@ -95,6 +103,11 @@ def test_transformer_model():


def test_transformer_model_2d():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
@@ -124,3 +137,89 @@ def test_transformer_model_2d():
model = Model(net_with_grad)

model.train(1, dataset, dataset_sink_mode=False)


class TransformerNet(nn.Cell):
"""Transformer with loss"""
def __init__(self, en_layer, de_layer, parallel_config):
super(TransformerNet, self).__init__()
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
moe_config=moe_config,
parallel_config=parallel_config)
self.loss = CrossEntropyLoss(parallel_config=config.moe_parallel_config)

def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)


def moe_with_loss_plus_mutiparallel(local_parallel_config):
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
label = Tensor(np.ones((20,)), mstype.int32)
input_mask = Tensor(np.ones((20,)), mstype.float32)

net = TransformerNet(en_layer=1, de_layer=1, parallel_config=local_parallel_config)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask, label, input_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)

model.train(1, dataset, dataset_sink_mode=False)


def test_moe_expert_parallel1():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)


def test_moe_expert_parallel2():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)


def test_moe_expert_parallel3():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)

def test_moe_expert_parallel_exception():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: data_parallel*model_parallel*expert_parallel > device_num
Expectation: Raise ValueError.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=4)
with pytest.raises(ValueError):
moe_with_loss_plus_mutiparallel(local_p_config)
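The expected failure follows directly from the capacity check added in `_check_moe_config`:

```python
# Capacity check from _check_moe_config, applied to test_moe_expert_parallel_exception:
data_parallel, model_parallel, expert_parallel, device_num = 1, 8, 4, 16
assert data_parallel * model_parallel * expert_parallel == 32   # 32 > 16 -> ValueError
```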