Optimize MOE infer latency

z00617246 2022-09-23 16:40:10 +08:00
parent 6277a023ed
commit 8e19972274
2 changed files with 148 additions and 25 deletions


@@ -1,4 +1,4 @@
.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None)
.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2)
Configuration for MoE (Mixture of Expert).
@@ -8,3 +8,6 @@
- **aux_loss_factor** (float) - The balance coefficient of the load-balancing (auxiliary) loss produced by the router. The scaled result is added to the total loss. The value should be less than 1.0. Default: 0.05.
- **num_experts_chosen** (int) - The number of experts selected for each token; the value must be less than or equal to the number of experts. Default: 1.
- **expert_group_size** (int) - The number of tokens each data parallel group receives. Default: None. This parameter only takes effect in auto parallel mode and not in sharding propagation mode.
- **group_wise_a2a** (bool) - Whether to enable group-wise alltoall communication, which can reduce communication time by converting part of the inter-node communication into intra-node communication. Default: False. This parameter only takes effect when the model parallel number is greater than 1 and the data parallel number equals the expert parallel number.
- **comp_comm_parallel** (bool) - Whether to enable parallel FFN compute and communication, which can reduce pure communication time by splitting and overlapping compute and communication. Default: False.
- **comp_comm_parallel_degree** (int) - The number of splits of compute and communication. The larger the number, the more overlap there is, but more memory is consumed. Default: 2. This parameter only takes effect when comp_comm_parallel is True. A usage sketch follows below.
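
The sketch below is illustrative only (the values passed are not the defaults; both features default to off, matching the code further down):

>>> from mindspore.nn.transformer import MoEConfig
>>> # Enable group-wise alltoall and split FFN compute/communication into 4 chunks.
>>> moe_config = MoEConfig(expert_num=4, group_wise_a2a=True,
...                        comp_comm_parallel=True, comp_comm_parallel_degree=4)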


@@ -52,21 +52,34 @@ class MoEConfig:
than expert_num. Default: 1.
expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is
effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
time by converting part of the inter-node communication into intra-node communication. Default: False. This
parameter is effective only when model parallel > 1 and data_parallel is equal to expert parallel.
comp_comm_parallel (bool): Whether to enable parallel FFN compute and communication, which can reduce pure
communication time by splitting and overlapping compute and communication. Default: False.
comp_comm_parallel_degree (int): The split number of compute and communication. The larger the number,
the more overlap there will be, but more memory will be consumed. Default: 2. This parameter is effective
only when comp_comm_parallel is enabled.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> from mindspore.nn.transformer import MoEConfig
>>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
...                        expert_group_size=64)
>>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
...                        comp_comm_parallel_degree=2)
"""
def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
             num_experts_chosen=1, expert_group_size=None):
def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1,
             expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2):
Validator.check_positive_int(expert_num, "expert_num")
Validator.check_positive_float(capacity_factor, "capacity_factor")
Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
Validator.check_bool(group_wise_a2a, "group_wise_a2a")
Validator.check_bool(comp_comm_parallel, "comp_comm_parallel")
Validator.check_positive_int(comp_comm_parallel_degree, "comp_comm_parallel_degree")
if expert_group_size is not None:
    Validator.check_positive_int(expert_group_size, "expert_group_size")
if capacity_factor < 1.0:
@@ -83,6 +96,9 @@ class MoEConfig:
self.aux_loss_factor = aux_loss_factor
self.num_experts_chosen = num_experts_chosen
self.expert_group_size = expert_group_size
self.group_wise_a2a = group_wise_a2a
self.comp_comm_parallel = comp_comm_parallel
self.comp_comm_parallel_degree = comp_comm_parallel_degree
default_moe_config = MoEConfig()
@@ -165,6 +181,12 @@ class MoE(Cell):
self.dp_group = parallel_config.data_parallel
self.dp = parallel_config.data_parallel
self.ep = parallel_config.expert_parallel
self.mp = parallel_config.model_parallel
self.comp_comm_parallel = moe_config.comp_comm_parallel
self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
self.group_wise_a2a = moe_config.group_wise_a2a
if not (self.mp > 1 and self.dp == self.ep):
    self.group_wise_a2a = False
from mindspore.nn.transformer import FeedForward
self.ffn = FeedForward(hidden_size=hidden_size,
@@ -178,15 +200,23 @@ class MoE(Cell):
self.reshape = P.Reshape()
self.shape = P.Shape()
self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.mul = P.Mul()
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                     training=True, parallel_config=parallel_config)
self.cast = P.Cast()
self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
else:
self.hidden_size = hidden_size
self.expert_dim = moe_config.expert_num
@@ -196,6 +226,12 @@ class MoE(Cell):
self.dp_group = parallel_config.data_parallel
self.dp = parallel_config.data_parallel
self.ep = parallel_config.expert_parallel
self.mp = parallel_config.model_parallel
self.comp_comm_parallel = moe_config.comp_comm_parallel
self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
self.group_wise_a2a = moe_config.group_wise_a2a
if not (self.mp > 1 and self.dp == self.ep):
    self.group_wise_a2a = False
from mindspore.nn.transformer import FeedForward
self.ffn = FeedForward(hidden_size=hidden_size,
@@ -208,15 +244,110 @@ class MoE(Cell):
self.reshape = P.Reshape()
self.shape = P.Shape()
self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.mul = P.Mul().shard(((), ()))
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                     training=True, parallel_config=parallel_config)
self.cast = P.Cast()
self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))

def ffn_infer(self, expert_input, capacity):
    """
    Computing the FFN.
    """
    pad_size = 0
    if self.group_wise_a2a:
        # If capacity can't be divided evenly by mp, pad it so the capacity axis can be sharded by mp.
        if capacity % self.mp != 0:
            pad_size = self.mp - (capacity % self.mp)
        if pad_size != 0:
            capacity += pad_size
            pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                              (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
                                              (1, 1, 1, 1))
            expert_input = self.concat_dp((expert_input, pad_tensor))
        # capacity shard by mp
        expert_input = self.stride_slice_dp_mp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                               (1, 1, 1, 1))
        # group-wise alltoall
        expert_input = self.stride_slice_ep_mp(expert_input, (0, 0, 0, 0),
                                               (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                               (1, 1, 1, 1))
        # allgather
        expert_input = self.stride_slice_ep(expert_input, (0, 0, 0, 0),
                                            (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                            (1, 1, 1, 1))
    expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * capacity,
                                               self.hidden_size))
    # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
    expert_output = self.ffn(expert_input)
    expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
                                                 capacity, self.hidden_size))
    if self.group_wise_a2a:
        # capacity shard by mp
        expert_output = self.stride_slice_ep_mp(expert_output, (0, 0, 0, 0),
                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                                (1, 1, 1, 1))
        # group-wise alltoall
        expert_output = self.stride_slice_dp_mp(expert_output, (0, 0, 0, 0),
                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                                (1, 1, 1, 1))
        # allgather
        expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
                                             (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                             (1, 1, 1, 1))
    # Slice the capacity dimension back to its original size.
    if pad_size != 0:
        capacity -= pad_size
        expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
                                             (self.expert_dim, self.dp_group, capacity, self.hidden_size),
                                             (1, 1, 1, 1))
    # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
    expert_output = self.transpose_4dim(expert_output, (1, 3, 0, 2))
    return expert_output
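
The padding above rounds the capacity up to a multiple of the model-parallel degree so the capacity axis can be sharded by mp before the group-wise alltoall. A standalone sketch of just that arithmetic (plain Python; the helper name is hypothetical and not part of the commit):

def pad_capacity_to_multiple(capacity, mp):
    # Mirrors ffn_infer: round capacity up to the next multiple of mp.
    pad_size = 0 if capacity % mp == 0 else mp - (capacity % mp)
    return capacity + pad_size, pad_size

# e.g. pad_capacity_to_multiple(10, 4) -> (12, 2); pad_capacity_to_multiple(12, 4) -> (12, 0)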

def ffn_parallel_infer(self, expert_input, capacity):
    """
    Split and overlap FFN compute and communication.
    """
    # Pad the capacity axis so it can be split into comp_comm_parallel_degree chunks.
    pad_size = 0
    if capacity % self.comp_comm_parallel_degree != 0:
        pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
        capacity += pad_size
        pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
                                          (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
                                          (1, 1, 1, 1))
        expert_input = self.concat_dp((expert_input, pad_tensor))
    sub_capacity = capacity // self.comp_comm_parallel_degree
    output_list = []
    for sub_expert_input in self.split(expert_input):
        sub_expert_output = self.ffn_infer(sub_expert_input, sub_capacity)
        output_list.append(sub_expert_output)
    expert_output = self.concat(output_list)
    # Slice the capacity dimension back to its original size.
    if pad_size != 0:
        capacity -= pad_size
        expert_output = self.stride_slice(expert_output, (0, 0, 0, 0),
                                          (self.dp_group, self.hidden_size, self.expert_dim, capacity),
                                          (1, 1, 1, 1))
    return expert_output
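
ffn_parallel_infer pads the capacity axis to a multiple of comp_comm_parallel_degree, splits it, runs each chunk through ffn_infer so one chunk's communication can overlap with the next chunk's compute, then concatenates and slices the padding off. A shape-level NumPy sketch of that dataflow (illustrative only; it ignores the output transpose and the actual overlap scheduling, and the function name is hypothetical):

import numpy as np

def chunked_ffn(expert_input, degree, ffn):
    # expert_input: (expert_dim, dp_group, capacity, hidden_size)
    e, d, c, h = expert_input.shape
    pad = (-c) % degree                      # pad capacity so it splits evenly
    if pad:
        zeros = np.zeros((e, d, pad, h), dtype=expert_input.dtype)
        expert_input = np.concatenate([expert_input, zeros], axis=2)
    chunks = np.split(expert_input, degree, axis=2)
    # The real cell overlaps each chunk's alltoall/allgather with the next chunk's
    # FFN compute; this serial loop only mirrors the split/concat dataflow.
    out = np.concatenate([ffn(chunk) for chunk in chunks], axis=2)
    return out[:, :, :c, :]

# e.g. chunked_ffn(np.ones((8, 2, 10, 16), np.float32), 4, lambda x: 2 * x).shape == (8, 2, 10, 16)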

def construct(self, input_tensor):
    input_shape = F.shape(input_tensor)
@@ -248,25 +379,14 @@ class MoE(Cell):
expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                           self.hidden_size))
# expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3))
expert_input = self.transpose_4dim_dp(expert_input, (0, 2, 1, 3))
expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity,
                                           self.hidden_size))
# expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
expert_output = self.ffn(expert_input)
expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
                                             expert_capacity, self.hidden_size))
# The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
# has bad performance
expert_output = self.reshape(expert_output, (self.expert_dim,
                                             self.dp_group * expert_capacity * self.hidden_size))
expert_output = self.transpose_2dim_ep(expert_output, (1, 0))
expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity,
                                             self.hidden_size * self.expert_dim))
expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
# expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim,
                                             expert_capacity))
if self.comp_comm_parallel:
    expert_output = self.ffn_parallel_infer(expert_input, expert_capacity)
else:
    expert_output = self.ffn_infer(expert_input, expert_capacity)
expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                             self.expert_dim * expert_capacity))
combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,