diff --git a/docs/api/api_python/transformer/mindspore.nn.transformer.MoEConfig.rst b/docs/api/api_python/transformer/mindspore.nn.transformer.MoEConfig.rst
index 742df75ebfb..40732815b4e 100644
--- a/docs/api/api_python/transformer/mindspore.nn.transformer.MoEConfig.rst
+++ b/docs/api/api_python/transformer/mindspore.nn.transformer.MoEConfig.rst
@@ -1,4 +1,4 @@
-.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None)
+.. py:class:: mindspore.nn.transformer.MoEConfig(expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2)
 
     MoE (Mixture of Expert)的配置。
 
@@ -8,3 +8,6 @@
     - **aux_loss_factor** (float) - 表示负载均衡损失(由路由器产生)的平衡系数。相乘的结果会加到总损失函数中。此系数的值小于1.0。默认值:0.05。
     - **num_experts_chosen** (int) - 表示每个标识选择的专家数量,其值小于等于专家数量。默认值:1。
     - **expert_group_size** (int) - 表示每个数据并行组收到的词语(token)数量。默认值:None。该参数只在自动并行且非策略传播模式下起作用。
+    - **group_wise_a2a** (bool) - 表示是否使能group-wise alltoall通信。group-wise alltoall通信可以把部分节点间通信转化为节点内通信,从而降低通信时间。默认值:False。该参数只有在模型并行数大于1且数据并行数等于专家并行数时生效。
+    - **comp_comm_parallel** (bool) - 表示是否使能ffn计算和通信并行,可以通过拆分并重叠计算和通信来减少纯通信时间。默认值:False。
+    - **comp_comm_parallel_degree** (int) - 表示计算和通信的拆分数量。数字越大重叠越多,但会消耗更多的显存。默认值:2。该参数只在comp_comm_parallel为True时生效。
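The new fields can be exercised with a minimal configuration sketch such as the one below. It assumes a MindSpore build that already contains this patch; the values and the comments about when each option takes effect simply restate the documentation above.

    from mindspore.nn.transformer import MoEConfig

    # group_wise_a2a only takes effect when model_parallel > 1 and
    # data_parallel == expert_parallel; otherwise MoE silently disables it.
    moe_config = MoEConfig(expert_num=4,
                           capacity_factor=5.0,
                           aux_loss_factor=0.05,
                           num_experts_chosen=1,
                           group_wise_a2a=True,          # turn part of inter-node alltoall into intra-node traffic
                           comp_comm_parallel=True,      # overlap FFN compute with communication
                           comp_comm_parallel_degree=2)  # more splits: more overlap, more memory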
diff --git a/mindspore/python/mindspore/nn/transformer/moe.py b/mindspore/python/mindspore/nn/transformer/moe.py
index 505e5215c0d..0f6d123ccdc 100644
--- a/mindspore/python/mindspore/nn/transformer/moe.py
+++ b/mindspore/python/mindspore/nn/transformer/moe.py
@@ -52,21 +52,34 @@ class MoEConfig:
             than expert_num. Default: 1.
         expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter
             is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION.
+        group_wise_a2a (bool): Whether to enable group-wise alltoall communication, which can reduce communication
+            time by converting part of inter-node communication into intra-node communication. Default: False.
+            This parameter is effective only when model_parallel > 1 and data_parallel is equal to expert_parallel.
+        comp_comm_parallel (bool): Whether to enable ffn compute and communication parallelization, which can
+            reduce pure communication time by splitting and overlapping compute and communication. Default: False.
+        comp_comm_parallel_degree (int): The split number of compute and communication. The larger the number,
+            the more overlap there will be, but more memory will be consumed. Default: 2. This parameter is
+            effective only when comp_comm_parallel is enabled.
+
     Supported Platforms:
         ``Ascend`` ``GPU``
 
     Examples:
         >>> from mindspore.nn.transformer import MoEConfig
         >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1,
-        ...                        expert_group_size=64)
+        ...                        expert_group_size=64, group_wise_a2a=True, comp_comm_parallel=False,
+        ...                        comp_comm_parallel_degree=2)
     """
 
-    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
-                 num_experts_chosen=1, expert_group_size=None):
+    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1,
+                 expert_group_size=None, group_wise_a2a=False, comp_comm_parallel=False, comp_comm_parallel_degree=2):
         Validator.check_positive_int(expert_num, "expert_num")
         Validator.check_positive_float(capacity_factor, "capacity_factor")
         Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
         Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
+        Validator.check_bool(group_wise_a2a, "group_wise_a2a")
+        Validator.check_bool(comp_comm_parallel, "comp_comm_parallel")
+        Validator.check_positive_int(comp_comm_parallel_degree, "comp_comm_parallel_degree")
         if expert_group_size is not None:
             Validator.check_positive_int(expert_group_size, "expert_group_size")
         if capacity_factor < 1.0:
@@ -83,6 +96,9 @@ class MoEConfig:
         self.aux_loss_factor = aux_loss_factor
         self.num_experts_chosen = num_experts_chosen
         self.expert_group_size = expert_group_size
+        self.group_wise_a2a = group_wise_a2a
+        self.comp_comm_parallel = comp_comm_parallel
+        self.comp_comm_parallel_degree = comp_comm_parallel_degree
 
 
 default_moe_config = MoEConfig()
@@ -165,6 +181,12 @@ class MoE(Cell):
             self.dp_group = parallel_config.data_parallel
             self.dp = parallel_config.data_parallel
             self.ep = parallel_config.expert_parallel
+            self.mp = parallel_config.model_parallel
+            self.comp_comm_parallel = moe_config.comp_comm_parallel
+            self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
+            self.group_wise_a2a = moe_config.group_wise_a2a
+            if not (self.mp > 1 and self.dp == self.ep):
+                self.group_wise_a2a = False
             from mindspore.nn.transformer import FeedForward
 
             self.ffn = FeedForward(hidden_size=hidden_size,
@@ -178,15 +200,23 @@ class MoE(Cell):
             self.reshape = P.Reshape()
             self.shape = P.Shape()
             self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
-            self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
             self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
-            self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
+            self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
+            self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
             self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.mul = P.Mul()
             self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None, training=True,
                                  parallel_config=parallel_config)
             self.cast = P.Cast()
+            self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
+            self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
+            self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
+            self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
+            self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
+            self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
+            self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
+            self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
         else:
             self.hidden_size = hidden_size
             self.expert_dim = moe_config.expert_num
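The enablement rule for group-wise alltoall can be checked in isolation. The sketch below uses a hypothetical helper name (group_wise_a2a_effective is not part of the patch) and illustrative device layouts; it only mirrors the guard if not (self.mp > 1 and self.dp == self.ep) shown in the hunk above.

    def group_wise_a2a_effective(model_parallel, data_parallel, expert_parallel, requested=True):
        """Mirror of the guard in MoE.__init__: group-wise alltoall stays enabled only
        when model parallel is split (> 1) and data parallel equals expert parallel."""
        return requested and model_parallel > 1 and data_parallel == expert_parallel

    # Illustrative layouts:
    assert group_wise_a2a_effective(model_parallel=2, data_parallel=4, expert_parallel=4)      # stays on
    assert not group_wise_a2a_effective(model_parallel=1, data_parallel=4, expert_parallel=4)  # mp == 1, disabled
    assert not group_wise_a2a_effective(model_parallel=2, data_parallel=8, expert_parallel=4)  # dp != ep, disabled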
@@ -196,6 +226,12 @@ class MoE(Cell):
             self.dp_group = parallel_config.data_parallel
             self.dp = parallel_config.data_parallel
             self.ep = parallel_config.expert_parallel
+            self.mp = parallel_config.model_parallel
+            self.comp_comm_parallel = moe_config.comp_comm_parallel
+            self.comp_comm_parallel_degree = moe_config.comp_comm_parallel_degree
+            self.group_wise_a2a = moe_config.group_wise_a2a
+            if not (self.mp > 1 and self.dp == self.ep):
+                self.group_wise_a2a = False
             from mindspore.nn.transformer import FeedForward
 
             self.ffn = FeedForward(hidden_size=hidden_size,
@@ -208,15 +244,110 @@ class MoE(Cell):
             self.reshape = P.Reshape()
             self.shape = P.Shape()
             self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
-            self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
             self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
-            self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
+            self.transpose_4dim = P.Transpose().shard(((1, self.dp, 1, 1),))
+            self.transpose_4dim_dp = P.Transpose().shard(((1, 1, self.dp, 1),))
             self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
             self.mul = P.Mul().shard(((), ()))
             self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None, training=True,
                                  parallel_config=parallel_config)
             self.cast = P.Cast()
+            self.concat = P.Concat(3).shard(tuple((self.dp, 1, 1, 1) for _ in range(self.comp_comm_parallel_degree)))
+            self.concat_dp = P.Concat(2).shard(((1, self.dp, 1, 1), (1, self.dp, 1, 1)))
+            self.split = P.Split(axis=2, output_num=self.comp_comm_parallel_degree).shard(((1, self.dp, 1, 1),))
+            self.stride_slice = P.StridedSlice().shard(((self.dp, 1, 1, 1),))
+            self.stride_slice_dp = P.StridedSlice().shard(((1, self.dp, 1, 1),))
+            self.stride_slice_ep = P.StridedSlice().shard(((self.ep, 1, 1, 1),))
+            self.stride_slice_dp_mp = P.StridedSlice().shard(((1, self.dp, self.mp, 1),))
+            self.stride_slice_ep_mp = P.StridedSlice().shard(((self.ep, 1, self.mp, 1),))
+
+    def ffn_infer(self, expert_input, capacity):
+        """
+        Compute the expert FFN. When group_wise_a2a is enabled, part of the inter-node alltoall
+        communication is converted into intra-node communication before and after the FFN.
+        """
+        pad_size = 0
+        if self.group_wise_a2a:
+            # If capacity can't be divided evenly by mp, pad it so the capacity dimension can be sharded by mp.
+            if capacity % self.mp != 0:
+                pad_size = self.mp - (capacity % self.mp)
+            if pad_size != 0:
+                capacity += pad_size
+                pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
+                                                  (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
+                                                  (1, 1, 1, 1))
+                expert_input = self.concat_dp((expert_input, pad_tensor))
+            # Shard capacity by mp.
+            expert_input = self.stride_slice_dp_mp(expert_input, (0, 0, 0, 0),
+                                                   (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                   (1, 1, 1, 1))
+            # Group-wise alltoall.
+            expert_input = self.stride_slice_ep_mp(expert_input, (0, 0, 0, 0),
+                                                   (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                   (1, 1, 1, 1))
+            # Allgather.
+            expert_input = self.stride_slice_ep(expert_input, (0, 0, 0, 0),
+                                                (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                (1, 1, 1, 1))
+
+        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * capacity,
+                                                   self.hidden_size))
+        # expert_output's shape: (self.expert_dim * self.dp_group * capacity, self.hidden_size)
+        expert_output = self.ffn(expert_input)
+        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
+                                                     capacity, self.hidden_size))
+
+        if self.group_wise_a2a:
+            # Shard capacity by mp.
+            expert_output = self.stride_slice_ep_mp(expert_output, (0, 0, 0, 0),
+                                                    (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                    (1, 1, 1, 1))
+            # Group-wise alltoall.
+            expert_output = self.stride_slice_dp_mp(expert_output, (0, 0, 0, 0),
+                                                    (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                    (1, 1, 1, 1))
+            # Allgather.
+            expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
+                                                 (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                 (1, 1, 1, 1))
+            # Slice the padded capacity back to its original size.
+            if pad_size != 0:
+                capacity -= pad_size
+                expert_output = self.stride_slice_dp(expert_output, (0, 0, 0, 0),
+                                                     (self.expert_dim, self.dp_group, capacity, self.hidden_size),
+                                                     (1, 1, 1, 1))
+        # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, capacity)
+        expert_output = self.transpose_4dim(expert_output, (1, 3, 0, 2))
+        return expert_output
+
+    def ffn_parallel_infer(self, expert_input, capacity):
+        """
+        Split and overlap FFN compute and communication.
+        """
+        # Pad capacity so it can be split into comp_comm_parallel_degree equal parts.
+        pad_size = 0
+        if capacity % self.comp_comm_parallel_degree != 0:
+            pad_size = self.comp_comm_parallel_degree - (capacity % self.comp_comm_parallel_degree)
+            capacity += pad_size
+            pad_tensor = self.stride_slice_dp(expert_input, (0, 0, 0, 0),
+                                              (self.expert_dim, self.dp_group, pad_size, self.hidden_size),
+                                              (1, 1, 1, 1))
+            expert_input = self.concat_dp((expert_input, pad_tensor))
+
+        sub_capacity = capacity // self.comp_comm_parallel_degree
+        output_list = []
+        for sub_expert_input in self.split(expert_input):
+            sub_expert_output = self.ffn_infer(sub_expert_input, sub_capacity)
+            output_list.append(sub_expert_output)
+        expert_output = self.concat(output_list)
+
+        # Slice the padded capacity back to its original size.
+        if pad_size != 0:
+            capacity -= pad_size
+            expert_output = self.stride_slice(expert_output, (0, 0, 0, 0),
+                                              (self.dp_group, self.hidden_size, self.expert_dim, capacity),
+                                              (1, 1, 1, 1))
+        return expert_output
 
     def construct(self, input_tensor):
         input_shape = F.shape(input_tensor)
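The padding and splitting arithmetic in ffn_parallel_infer can be verified on its own. The sketch below is plain Python with a hypothetical helper name (split_capacity is not part of the patch); it only mirrors how pad_size and the per-chunk sub_capacity are derived from the expert capacity and comp_comm_parallel_degree.

    def split_capacity(capacity, degree):
        """Pad capacity up to a multiple of `degree`, then split it evenly,
        mirroring the arithmetic in MoE.ffn_parallel_infer."""
        pad_size = 0
        if capacity % degree != 0:
            pad_size = degree - (capacity % degree)
        sub_capacity = (capacity + pad_size) // degree
        return pad_size, sub_capacity

    # capacity=21 with a split degree of 2: one padded slot, two chunks of 11.
    assert split_capacity(21, 2) == (1, 11)
    # Already divisible: no padding, two chunks of 12.
    assert split_capacity(24, 2) == (0, 12)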
@@ -248,25 +379,14 @@ class MoE(Cell):
         expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                                    self.hidden_size))
         # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
-        expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3))
-        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity,
-                                                   self.hidden_size))
+        expert_input = self.transpose_4dim_dp(expert_input, (0, 2, 1, 3))
 
-        # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
-        expert_output = self.ffn(expert_input)
-        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
-                                                     expert_capacity, self.hidden_size))
-        # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
-        # has bad performance
-        expert_output = self.reshape(expert_output, (self.expert_dim,
-                                                     self.dp_group * expert_capacity * self.hidden_size))
-        expert_output = self.transpose_2dim_ep(expert_output, (1, 0))
-        expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity,
-                                                     self.hidden_size * self.expert_dim))
-        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
         # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
-        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim,
-                                                     expert_capacity))
+        if self.comp_comm_parallel:
+            expert_output = self.ffn_parallel_infer(expert_input, expert_capacity)
+        else:
+            expert_output = self.ffn_infer(expert_input, expert_capacity)
+
         expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                                      self.expert_dim * expert_capacity))
         combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
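For reference, both ffn_infer and ffn_parallel_infer take expert_input of shape (expert_dim, dp_group, expert_capacity, hidden_size) and return (dp_group, hidden_size, expert_dim, expert_capacity), which construct then flattens before the batch_mm2 with combine_tensor. A NumPy stand-in of that layout change (illustrative sizes, not part of the patch):

    import numpy as np

    expert_dim, dp_group, expert_capacity, hidden_size = 4, 2, 6, 8

    # Stand-in for the per-expert FFN output before the final transpose in ffn_infer.
    expert_output = np.zeros((expert_dim, dp_group, expert_capacity, hidden_size))

    # Mirrors transpose_4dim(expert_output, (1, 3, 0, 2)) in the patch.
    expert_output = expert_output.transpose(1, 3, 0, 2)
    assert expert_output.shape == (dp_group, hidden_size, expert_dim, expert_capacity)

    # construct() then flattens the last two axes before batch_mm2 with combine_tensor.
    expert_output = expert_output.reshape(dp_group, hidden_size, expert_dim * expert_capacity)
    assert expert_output.shape == (2, 8, 24)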