diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
index 8ab2fc5517f..e6c5f5ebb0c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
@@ -275,6 +275,29 @@ std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
   return sp_vector;
 }
 
+Status CumSumInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
+  std::vector<Group> input_a_group;
+  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Create group for input a failed.";
+    return FAILED;
+  }
+  OperatorVector op_for_input_a, op_for_axis;
+  if (input_a_group.empty()) {
+    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
+    return SUCCESS;
+  } else {
+    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
+    MS_LOG(INFO) << name_ << ": Create the mirror ops for input a success, groups is " << input_a_group[0].name();
+  }
+
+  mirror_ops_.push_back(op_for_input_a);
+  mirror_ops_.push_back(op_for_axis);
+
+  return SUCCESS;
+}
+
 Status CumSumInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 Status ActivationBase::InferDevMatrixShape() {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
index 8da41b868f8..a4c255f162d 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
@@ -130,6 +130,7 @@ class CumSumInfo : public ActivationBase {
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferMirrorOps() override;
   Status GetAttrs() override;
 
  private:
diff --git a/mindspore/python/mindspore/nn/transformer/moe.py b/mindspore/python/mindspore/nn/transformer/moe.py
index 11ec2e949e1..563325668de 100644
--- a/mindspore/python/mindspore/nn/transformer/moe.py
+++ b/mindspore/python/mindspore/nn/transformer/moe.py
@@ -153,7 +153,7 @@ class MoE(Cell):
         input_shape = F.shape(input_tensor)
         input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
         bs_and_dmodel = self.shape(input_tensor)
-        tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
+        tokens_per_device = bs_and_dmodel[0] // self.expert_parallel
         input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device,
                                                    self.hidden_size))
         expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
@@ -217,62 +217,6 @@ class MoE(Cell):
         return combined_output, aux_loss
 
 
-class _CumSum(Cell):
-    r"""
-    A layer used to calculate cumulative summation of a tensor along a dimension.
-
-    Inputs:
-        - **expert_mask** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
-          expert\_dim)`.
-
-    Outputs:
-        Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
- """ - - def __init__(self, config): - super(_CumSum, self).__init__() - dp = config.data_parallel - self.range = P.Range().shard(((1,),)) - self.reshape = P.Reshape() - self.matmul = P.MatMul().shard(((dp, 1), (1, 1))) - self.shape = P.Shape() - self.cast = P.Cast() - - self.transpose = P.Transpose().shard(((dp, 1, 1),)) - self.transpose2 = P.Transpose().shard(((1, 1),)) - self.transpose3 = P.Transpose().shard(((dp, 1, 1),)) - self.expand = P.ExpandDims().shard(((1,),)) - self.greater = P.Greater().shard(((1, 1), (1, 1))) - - self.start = Tensor(0, mstype.int32) - self.limit = Tensor(0, mstype.int32) - self.delta = Tensor(1, mstype.int32) - self.add = P.Add().shard(((1,), ())) - - def construct(self, expert_mask): - # origin_shape: (expert_parallel, tokens_per_device, self.expert_dim) - origin_shape = self.shape(expert_mask) - tokens_per_device = origin_shape[1] - # expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device) - expert_mask_trans = self.transpose(expert_mask, (0, 2, 1)) - # expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device) - expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device)) - - one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0) - other_dim = self.transpose2(one_dim, (1, 0)) - # up_tri_matrix's shape: (tokens_per_device, tokens_per_device) - up_tri_matrix = self.greater(one_dim, other_dim) - up_tri_matrix = self.cast(up_tri_matrix, mstype.float32) - - # cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device) - cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix) - # cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device) - cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device)) - # cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim) - cum_sum = self.transpose3(cum_sum, (0, 2, 1)) - return cum_sum - - class Router(Cell): r""" A router backbone used to calculate logits of each token, which should be cascaded by router implementations @@ -390,7 +334,7 @@ class SwitchRouter(Cell): self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ())) - self.cumsum = _CumSum(config=parallel_config) + self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),)) self.less = P.Less().shard(((dp, 1, 1), ())) self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),)) self.expand = P.ExpandDims().shard(((dp, 1),)) @@ -413,7 +357,7 @@ class SwitchRouter(Cell): """ Keeping only the tokens that fit within expert_capacity. """ - cumsum = self.cumsum(expert_mask) + cumsum = self.cumsum(expert_mask, 1) # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim) position_in_expert = self.mul4(cumsum, expert_mask) less_result = self.less(position_in_expert, expert_capacity) @@ -431,7 +375,7 @@ class SwitchRouter(Cell): router_logits_shape = self.shape(router_logits) router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1])) logits_shape = self.shape(router_logits) - tokens_per_device = logits_shape[0] / self.expert_parallel + tokens_per_device = logits_shape[0] // self.expert_parallel expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim) router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim)) # Currently, lack of gumbel sampler for router_logits. 
diff --git a/tests/ut/python/parallel/test_parallel_cumsum.py b/tests/ut/python/parallel/test_parallel_cumsum.py
index fe5eef2cf72..371af47aa4b 100644
--- a/tests/ut/python/parallel/test_parallel_cumsum.py
+++ b/tests/ut/python/parallel/test_parallel_cumsum.py
@@ -104,6 +104,33 @@ def test_cumsum_semi2():
     compile_net(net, x, y)
 
 
+def test_cumsum_semi3():
+    """
+    Feature: CumSum operatorInfo in parallel.
+    Description: MatMul->CumSum
+    Expectation: Compile done without error.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
+            self.cumsum = P.CumSum().shard(((2, 1),))
+
+        def construct(self, x, y):
+            out = self.matmul1(x, y)
+            out = self.cumsum(out, 1)
+            return out
+
+    size = 16
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    compile_net(net, x, y)
+
+
 def test_cumsum_auto():
     """
     Feature: CumSum operatorInfo in parallel.