forked from mindspore-Ecosystem/mindspore
Optimize MOE infer latency
This commit is contained in:
parent 6277a023ed
commit 8e19972274
@@ -1,4 +1,4 @@
-.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None)
+.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2)
 
     Configuration for MoE (Mixture of Expert).
 
@@ -8,3 +8,6 @@
 - **aux_loss_factor** (float) - The balance coefficient for the load-balancing loss (produced by the router). The multiplied result is added to the total loss. The value of this coefficient is less than 1.0. Default: 0.05.
 - **num_experts_chosen** (int) - The number of experts chosen by each token; the value is less than or equal to the number of experts. Default: 1.
 - **expert_group_size** (int) - The number of tokens received by each data parallel group. Default: None. This parameter takes effect only in auto-parallel mode, and not in sharding propagation mode.
+- **group_wise_a2a** (bool) - Whether to enable group-wise alltoall communication, which can reduce communication time by converting part of the inter-node communication into intra-node communication. Default: False. This parameter takes effect only when model parallel is greater than 1 and data parallel equals expert parallel.
+- **comp_comm_parallel** (bool) - Whether to enable parallel FFN computation and communication, which can reduce pure communication time by splitting and overlapping computation and communication. Default: False.
+- **comp_comm_parallel_degree** (int) - The number of splits of computation and communication. The larger the number, the more overlap there is, but more memory is consumed. Default: 2. This parameter takes effect only when comp_comm_parallel is True.
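For reference, a minimal usage sketch of the options this commit adds (the values below are illustrative, not from the commit): comp_comm_parallel_degree is only read when comp_comm_parallel is enabled, and group_wise_a2a additionally requires model parallel > 1 with data parallel equal to expert parallel.

    from mindspore.nn.transformer import MoEConfig

    # Illustrative values: split FFN compute/communication into 4 overlapped chunks.
    # group_wise_a2a is left off here; it only takes effect when model parallel > 1
    # and data parallel == expert parallel.
    moe_config = MoEConfig(expert_num=8,
                           capacity_factor=1.1,
                           aux_loss_factor=0.05,
                           num_experts_chosen=1,
                           group_wise_a2a=False,
                           comp_comm_parallel=True,
                           comp_comm_parallel_degree=4)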
@@ -52,21 +52,34 @@ class MoEConfig:
             than expert_num. Default: 1.
         expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
             effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
+        group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
+            time by converting part of inter-node communication into intra-node communication. Default: False. This
+            parameter is effective only when model parallel > 1 and data_parallel equals expert parallel.
+        comp_comm_parallel (bool): Whether to enable parallel FFN compute and communication, which can reduce pure
+            communication time by splitting and overlapping compute and communication. Default: False.
+        comp_comm_parallel_degree (int): The split number of compute and communication. The larger the number, the
+            more overlap there will be, but more memory will be consumed. Default: 2. This parameter is effective
+            only when comp_comm_parallel is enabled.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
 
     Examples:
         >>> from mindspore.nn.transformer import MoEConfig
         >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
-        ...                        expert_group_size=64)
+        ...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
+        ...                        comp_comm_parallel_degree=2)
     """
 
-    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
-                 num_experts_chosen=1, expert_group_size=None):
+    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1,
+                 expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2):
         Validator.check_positive_int(expert_num, "expert_num")
         Validator.check_positive_float(capacity_factor, "capacity_factor")
         Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
         Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
+        Validator.check_bool(group_wise_a2a, "group_wise_a2a")
+        Validator.check_bool(comp_comm_parallel, "comp_comm_parallel")
+        Validator.check_positive_int(comp_comm_parallel_degree, "comp_comm_parallel_degree")
         if expert_group_size is not None:
             Validator.check_positive_int(expert_group_size, "expert_group_size")
         if capacity_factor < 1.0:
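The "more overlap, more memory" trade-off behind comp_comm_parallel_degree can be seen with a back-of-envelope pipeline model. This is entirely our own sketch, not MindSpore code: with capacity split into d chunks, compute of one chunk overlaps communication of the previous one, so total time falls from comp + comm toward max(comp, comm).

    def pipelined_time(comp, comm, degree):
        # Two-resource pipeline: the first compute chunk fills the pipe, then each
        # step is bounded by the slower of per-chunk compute and communication.
        chunk_comp = comp / degree
        chunk_comm = comm / degree
        return chunk_comp + max(comp, comm) * (degree - 1) / degree + chunk_comm

    print(pipelined_time(10.0, 8.0, 1))  # 18.0 -> no overlap
    print(pipelined_time(10.0, 8.0, 2))  # 14.0
    print(pipelined_time(10.0, 8.0, 4))  # 12.0 -> diminishing returns, more buffers held live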
@@ -83,6 +96,9 @@ class MoEConfig:
         self.aux_loss_factor = aux_loss_factor
         self.num_experts_chosen = num_experts_chosen
         self.expert_group_size = expert_group_size
+        self.group_wise_a2a = group_wise_a2a
+        self.comp_comm_parallel = comp_comm_parallel
+        self.comp_comm_parallel_degree = comp_comm_parallel_degree
 
 
 default_moe_config = MoEConfig()
@@ -165,6 +181,12 @@ class MoE(Cell):
             self.dp_group = parallel_config.data_parallel
             self.dp = parallel_config.data_parallel
             self.ep = parallel_config.expert_parallel
+            self.mp = parallel_config.model_parallel
+            self.comp_comm_parallel = moe_config.comp_comm_parallel
+            self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
+            self.group_wise_a2a = moe_config.group_wise_a2a
+            if not (self.mp > 1 and self.dp == self.ep):
+                self.group_wise_a2a = False
             from mindspore.nn.transformer import FeedForward
 
             self.ffn = FeedForward(hidden_size=hidden_size,
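This hunk (and its twin in the other branch below) silently downgrades group_wise_a2a when the parallel layout cannot support it. A standalone restatement in plain Python, with a helper name of our own choosing:

    def group_wise_a2a_effective(moe_config, parallel_config):
        # group-wise alltoall is disabled unless a model-parallel dimension exists
        # (mp > 1) and data parallel matches expert parallel (dp == ep).
        mp = parallel_config.model_parallel
        dp = parallel_config.data_parallel
        ep = parallel_config.expert_parallel
        return moe_config.group_wise_a2a and mp > 1 and dp == ep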
@@ -178,15 +200,23 @@ class MoE(Cell):
             self.reshape = P.Reshape()
             self.shape = P.Shape()
             self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
-            self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
             self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
-            self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
+            self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
+            self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
             self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.mul = P.Mul()
             self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                                  training=True, parallel_config=parallel_config)
             self.cast = P.Cast()
+            self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
+            self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
+            self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
+            self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
+            self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
+            self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
+            self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
+            self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
         else:
             self.hidden_size = hidden_size
             self.expert_dim = moe_config.expert_num
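A plain-Python sketch (our own helper, not MindSpore API) of what these shard strategies express: each entry of the tuple divides the corresponding tensor dimension across devices, and re-slicing the same tensor under a different strategy is what makes the framework insert the alltoall/allgather redistributions used in ffn_infer below.

    def per_device_shape(shape, strategy):
        # Each strategy entry must evenly divide its dimension.
        assert all(dim % cut == 0 for dim, cut in zip(shape, strategy))
        return tuple(dim // cut for dim, cut in zip(shape, strategy))

    dp = ep = 4
    mp = 2
    full = (8, dp, 1024, 4096)  # (expert_dim, dp_group, capacity, hidden_size), illustrative
    print(per_device_shape(full, (1, dp, mp, 1)))  # stride_slice_dp_mp view: (8, 1, 512, 4096)
    print(per_device_shape(full, (ep, 1, mp, 1)))  # stride_slice_ep_mp view: (2, 4, 512, 4096)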
@@ -196,6 +226,12 @@ class MoE(Cell):
             self.dp_group = parallel_config.data_parallel
             self.dp = parallel_config.data_parallel
             self.ep = parallel_config.expert_parallel
+            self.mp = parallel_config.model_parallel
+            self.comp_comm_parallel = moe_config.comp_comm_parallel
+            self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
+            self.group_wise_a2a = moe_config.group_wise_a2a
+            if not (self.mp > 1 and self.dp == self.ep):
+                self.group_wise_a2a = False
             from mindspore.nn.transformer import FeedForward
 
             self.ffn = FeedForward(hidden_size=hidden_size,
@@ -208,15 +244,110 @@ class MoE(Cell):
             self.reshape = P.Reshape()
             self.shape = P.Shape()
             self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
-            self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
             self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
-            self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
+            self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
+            self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
             self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.mul = P.Mul().shard(((), ()))
             self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                                  training=True, parallel_config=parallel_config)
             self.cast = P.Cast()
+            self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
+            self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
+            self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
+            self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
+            self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
+            self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
+            self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
+            self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
+
+    def ffn_infer(self, expert_input, capacity):
+        """
+        Computing the FFN.
+        """
+        pad_size = 0
+        if self.group_wise_a2a:
+            # If capacity can't be divided evenly by mp, pad it for mp sharding.
+            if capacity % self.mp != 0:
+                pad_size = self.mp - (capacity % self.mp)
+            if pad_size != 0:
+                capacity += pad_size
+                pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
+                                                  (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
+                                                  (1, 1, 1, 1))
+                expert_input = self.concat_dp((expert_input, pad_tensor))
+            # capacity sharded by mp
+            expert_input = self.stride_slice_dp_mp(expert_input, (0, 0, 0, 0),
+                                                   (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                   (1, 1, 1, 1))
+            # group-wise alltoall
+            expert_input = self.stride_slice_ep_mp(expert_input, (0, 0, 0, 0),
+                                                   (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                   (1, 1, 1, 1))
+            # allgather
+            expert_input = self.stride_slice_ep(expert_input, (0, 0, 0, 0),
+                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                (1, 1, 1, 1))
+
+        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * capacity,
+                                                   self.hidden_size))
+        # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
+        expert_output = self.ffn(expert_input)
+        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
+                                                     capacity, self.hidden_size))
+
+        if self.group_wise_a2a:
+            # capacity sharded by mp
+            expert_output = self.stride_slice_ep_mp(expert_output, (0, 0, 0, 0),
+                                                    (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                    (1, 1, 1, 1))
+            # group-wise alltoall
+            expert_output = self.stride_slice_dp_mp(expert_output, (0, 0, 0, 0),
+                                                    (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                    (1, 1, 1, 1))
+            # allgather
+            expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
+                                                 (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                 (1, 1, 1, 1))
+        # Slice capacity back to the original shape.
+        if pad_size != 0:
+            capacity -= pad_size
+            expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
+                                                 (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                 (1, 1, 1, 1))
+        # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
+        expert_output = self.transpose_4dim(expert_output, (1, 3, 0, 2))
+        return expert_output
+
+    def ffn_parallel_infer(self, expert_input, capacity):
+        """
+        Split and overlap FFN compute and communication.
+        """
+        # Pad capacity so it can be split into comp_comm_parallel_degree parts.
+        pad_size = 0
+        if capacity % self.comp_comm_parallel_degree != 0:
+            pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
+            capacity += pad_size
+            pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
+                                              (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
+                                              (1, 1, 1, 1))
+            expert_input = self.concat_dp((expert_input, pad_tensor))
+
+        sub_capacity = capacity // self.comp_comm_parallel_degree
+        output_list = []
+        for sub_expert_input in self.split(expert_input):
+            sub_expert_output = self.ffn_infer(sub_expert_input, sub_capacity)
+            output_list.append(sub_expert_output)
+        expert_output = self.concat(output_list)
+
+        # Slice capacity back to the original shape.
+        if pad_size != 0:
+            capacity -= pad_size
+            expert_output = self.stride_slice(expert_output, (0, 0, 0, 0),
+                                              (self.dp_group, self.hidden_size, self.expert_dim, capacity),
+                                              (1, 1, 1, 1))
+        return expert_output
+
     def construct(self, input_tensor):
         input_shape = F.shape(input_tensor)
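Both new methods share the same pad-then-slice-back arithmetic; a standalone restatement in plain Python (the helper name is ours) that can be checked by hand:

    def padded_capacity(capacity, factor):
        # Pad capacity up to the next multiple of factor (mp for group-wise
        # alltoall, comp_comm_parallel_degree for the overlapped split).
        pad_size = 0
        if capacity % factor != 0:
            pad_size = factor - (capacity % factor)
        return capacity + pad_size, pad_size

    capacity, pad_size = padded_capacity(10, 4)  # -> (12, 2)
    sub_capacity = capacity // 4                 # each of the 4 overlapped chunks handles 3
    assert capacity - pad_size == 10             # sliced back to the original shape at the end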
@@ -248,25 +379,14 @@ class MoE(Cell):
         expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                                    self.hidden_size))
         # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
-        expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3))
-        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity,
-                                                   self.hidden_size))
-
-        # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
-        expert_output = self.ffn(expert_input)
-        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
-                                                     expert_capacity, self.hidden_size))
-        # The following five ops implement transpose(expert_output, (1, 3, 0, 2)), because a single transpose
-        # has bad performance
-        expert_output = self.reshape(expert_output, (self.expert_dim,
-                                                     self.dp_group * expert_capacity * self.hidden_size))
-        expert_output = self.transpose_2dim_ep(expert_output, (1, 0))
-        expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity,
-                                                     self.hidden_size * self.expert_dim))
-        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
+        expert_input = self.transpose_4dim_dp(expert_input, (0, 2, 1, 3))
         # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
-        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim,
-                                                     expert_capacity))
+        if self.comp_comm_parallel:
+            expert_output = self.ffn_parallel_infer(expert_input, expert_capacity)
+        else:
+            expert_output = self.ffn_infer(expert_input, expert_capacity)
+
         expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                                      self.expert_dim * expert_capacity))
         combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
                                                        
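To close the loop, a shape walk-through of the rewritten inference path (constants are illustrative; the shapes follow the reshapes and transposes in the hunks above):

    expert_dim, dp_group, expert_capacity, hidden_size = 8, 4, 1024, 4096

    ffn_in = (expert_dim, dp_group, expert_capacity, hidden_size)          # after transpose_4dim_dp in construct
    ffn_flat = (expert_dim * dp_group * expert_capacity, hidden_size)      # reshape fed to self.ffn
    ffn_out = (dp_group, hidden_size, expert_dim, expert_capacity)         # after transpose_4dim (1, 3, 0, 2)
    combine_ready = (dp_group, hidden_size, expert_dim * expert_capacity)  # final reshape before combine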