1) fix the exact division in MoE;
2) change CumSum from a composition of operators to the single CumSum operator;
3) add InferMirrorOps for CumSumInfo.
Xiaoda Zhang 2021-12-30 16:33:02 +08:00
parent 7c241bbaf5
commit 6d8320fa66
4 changed files with 55 additions and 60 deletions
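Note (annotation, not part of the commit) on item 1: true division returns a float even when the operands divide evenly, and a float would not be accepted as a dimension in the later reshape, whereas floor division keeps tokens_per_device an integer. A minimal sketch in plain Python with hypothetical numbers:

    bs_and_dmodel_0 = 4096   # hypothetical flattened batch*seq length
    expert_parallel = 8      # hypothetical expert-parallel degree

    print(bs_and_dmodel_0 / expert_parallel)    # 512.0 -> float, not usable as a shape element
    print(bs_and_dmodel_0 // expert_parallel)   # 512   -> int, valid for the later reshape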

View File

@ -275,6 +275,29 @@ std::vector<StrategyPtr> CumSumInfo::GenerateOpStrategies(int64_t stage_id) {
  return sp_vector;
}

Status CumSumInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_a_group;
  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group for input a failed.";
    return FAILED;
  }

  OperatorVector op_for_input_a, op_for_axis;
  if (input_a_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
    return SUCCESS;
  } else {
    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
    MS_LOG(INFO) << name_ << ": Create the mirror ops for input a success, groups is " << input_a_group[0].name();
  }

  mirror_ops_.push_back(op_for_input_a);
  mirror_ops_.push_back(op_for_axis);
  return SUCCESS;
}

Status CumSumInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

Status ActivationBase::InferDevMatrixShape() {

View File

@ -130,6 +130,7 @@ class CumSumInfo : public ActivationBase {
 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;
  Status GetAttrs() override;

 private:

View File

@ -153,7 +153,7 @@ class MoE(Cell):
        input_shape = F.shape(input_tensor)
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
-       tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
+       tokens_per_device = bs_and_dmodel[0] // self.expert_parallel
        input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size))
        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
@ -217,62 +217,6 @@ class MoE(Cell):
        return combined_output, aux_loss


class _CumSum(Cell):
    r"""
    A layer used to calculate cumulative summation of a tensor along a dimension.

    Inputs:
        - **expert_mask** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
          expert\_dim)`.

    Outputs:
        Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
    """

    def __init__(self, config):
        super(_CumSum, self).__init__()
        dp = config.data_parallel
        self.range = P.Range().shard(((1,),))
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(((dp, 1), (1, 1)))
        self.shape = P.Shape()
        self.cast = P.Cast()
        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.transpose2 = P.Transpose().shard(((1, 1),))
        self.transpose3 = P.Transpose().shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((1,),))
        self.greater = P.Greater().shard(((1, 1), (1, 1)))
        self.start = Tensor(0, mstype.int32)
        self.limit = Tensor(0, mstype.int32)
        self.delta = Tensor(1, mstype.int32)
        self.add = P.Add().shard(((1,), ()))

    def construct(self, expert_mask):
        # origin_shape: (expert_parallel, tokens_per_device, self.expert_dim)
        origin_shape = self.shape(expert_mask)
        tokens_per_device = origin_shape[1]
        # expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device)
        expert_mask_trans = self.transpose(expert_mask, (0, 2, 1))
        # expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device)
        expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device))
        one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0)
        other_dim = self.transpose2(one_dim, (1, 0))
        # up_tri_matrix's shape: (tokens_per_device, tokens_per_device)
        up_tri_matrix = self.greater(one_dim, other_dim)
        up_tri_matrix = self.cast(up_tri_matrix, mstype.float32)
        # cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device)
        cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix)
        # cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device)
        cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device))
        # cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        cum_sum = self.transpose3(cum_sum, (0, 2, 1))
        return cum_sum


class Router(Cell):
    r"""
    A router backbone used to calculate logits of each token, which should be cascaded by router implementations
@ -390,7 +334,7 @@ class SwitchRouter(Cell):
        self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))
-       self.cumsum = _CumSum(config=parallel_config)
+       self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),))
        self.less = P.Less().shard(((dp, 1, 1), ()))
        self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((dp, 1),))
@ -413,7 +357,7 @@ class SwitchRouter(Cell):
"""
Keeping only the tokens that fit within expert_capacity.
"""
cumsum = self.cumsum(expert_mask)
cumsum = self.cumsum(expert_mask, 1)
# position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim)
position_in_expert = self.mul4(cumsum, expert_mask)
less_result = self.less(position_in_expert, expert_capacity)
@ -431,7 +375,7 @@ class SwitchRouter(Cell):
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
-       tokens_per_device = logits_shape[0] / self.expert_parallel
+       tokens_per_device = logits_shape[0] // self.expert_parallel
        expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim)
        router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim))
        # Currently, lack of gumbel sampler for router_logits.
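Note (annotation, not part of the commit): the removed _CumSum composition computes an exclusive cumulative sum over the token axis. The Greater comparison builds a strictly upper-triangular matrix M with M[t, j] = 1 when j > t, so matmul(mask, M)[n, j] sums mask[n, t] over t < j, which is what a single exclusive CumSum along axis 1 produces. A minimal NumPy sketch (hypothetical shapes, independent of MindSpore) checking that equivalence:

    import numpy as np

    # Hypothetical sizes: expert_parallel=2, tokens_per_device=5, expert_dim=3.
    expert_mask = (np.random.rand(2, 5, 3) > 0.5).astype(np.float32)
    tokens = expert_mask.shape[1]

    # Old composition: strictly upper-triangular matrix M[t, j] = (j > t), then matmul.
    idx = np.arange(tokens)
    up_tri = (idx[None, :] > idx[:, None]).astype(np.float32)           # (tokens, tokens)
    reshaped = expert_mask.transpose(0, 2, 1).reshape(-1, tokens)       # (ep*expert_dim, tokens)
    old = (reshaped @ up_tri).reshape(2, 3, tokens).transpose(0, 2, 1)  # back to (ep, tokens, expert_dim)

    # New single operator: exclusive cumsum along the token axis (axis 1),
    # emulated here by shifting an inclusive cumsum.
    new = np.cumsum(expert_mask, axis=1) - expert_mask

    assert np.allclose(old, new)

Presumably this equivalence is why the five sharded operators could be replaced by the single P.CumSum(exclusive=True) call.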

View File

@ -104,6 +104,33 @@ def test_cumsum_semi2():
    compile_net(net, x, y)


def test_cumsum_semi3():
    """
    Feature: CumSum operatorInfo in parallel.
    Description: MatMul->CumSum
    Expectation: Compile done without error.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul().shard(((16, 1), (1, 1)))
            self.cumsum = P.CumSum().shard(((2, 1),))

        def construct(self, x, y):
            out = self.matmul1(x, y)
            out = self.cumsum(out, 1)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    compile_net(net, x, y)


def test_cumsum_auto():
    """
    Feature: CumSum operatorInfo in parallel.