!35839 Fix MoE bug

Merge pull request !35839 from bichaoyang/master
i-robot, 2022-06-17 08:15:25 +00:00 (committed by Gitee). Commit: e902c7c731
3 changed files with 241 additions and 37 deletions

Changed file 1 of 3

@ -303,6 +303,7 @@ class Router(Cell):
parallel_config=None):
super(Router, self).__init__()
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
dp = parallel_config.data_parallel
self.d_model = d_model
self.expert_dim = moe_config.expert_num
self.capacity_factor = moe_config.capacity_factor
@ -314,6 +315,7 @@ class Router(Cell):
self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))
self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
self.dense.matmul.shard(((dp, 1), (1, 1)))
self.mul = P.Mul()
self.cast = P.Cast()
@ -430,6 +432,8 @@ class TopkRouter(Cell):
self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True)
self.expand = P.ExpandDims()
self.expand2 = P.ExpandDims()
self.add_scala = P.Add()
self.init_loss = Tensor(0.0, mstype.float32)
else:
dp = parallel_config.data_parallel
self.d_model = d_model
@ -479,6 +483,8 @@ class TopkRouter(Cell):
self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1, 1),))
self.expand = P.ExpandDims().shard(((dp, 1),))
self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))
self.add_scala = P.Add().shard(((), ()))
self.init_loss = Tensor(0.0, mstype.float32)
def _auxiliary_loss(self, expert_mask, router_prob):
"""
@ -525,7 +531,7 @@ class TopkRouter(Cell):
accum_expert_mask = 0
accum_expert_gate = 0
loss = 0
loss = self.init_loss
mask_count = 0
accum_combine_tensor = 0
# Probabilities for each token of which expert it should be sent to
@ -542,7 +548,7 @@ class TopkRouter(Cell):
router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))
# the balance loss is computed at each routing step
loss += self._auxiliary_loss(expert_mask, router_prob_normal)
loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal))
output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
mask_count, expert_chosen_index)
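
Taken together, the hunks above make the MoE router usable under sharding propagation: Router now defines dp from parallel_config and shards its dense matmul with ((dp, 1), (1, 1)), and TopkRouter accumulates the auxiliary balance loss through an explicit P.Add (add_scala, given the empty scalar shard strategy ((), ()) in the strategy-configured branch) starting from the float32 Tensor init_loss instead of the Python literal 0. A minimal sketch of that accumulation pattern, assuming MindSpore is available; the cell below is illustrative and not the actual TopkRouter:

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

class AuxLossAccumulator(nn.Cell):
    """Accumulate per-step auxiliary losses through an explicit scalar Add."""
    def __init__(self):
        super(AuxLossAccumulator, self).__init__()
        # Scalar operands, hence the empty shard strategies.
        self.add_scala = P.Add().shard(((), ()))
        self.init_loss = Tensor(0.0, mstype.float32)

    def construct(self, step_losses):
        # Start from a Tensor rather than the Python int 0, so each addition is a graph op.
        loss = self.init_loss
        for step_loss in step_losses:
            loss = self.add_scala(loss, step_loss)
        return loss

For example, AuxLossAccumulator()((Tensor(0.1, mstype.float32), Tensor(0.2, mstype.float32))) returns a scalar float32 Tensor close to 0.3.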

Changed file 2 of 3

@ -22,6 +22,7 @@ from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig, CrossEntropyLoss
from mindspore.parallel import set_algo_parameters
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
@ -80,15 +81,7 @@ class NetWithLossMoe(nn.Cell):
return self.add(predict, moe_loss)
def test_transformer_model():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def run_transformer_model():
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,
@ -116,15 +109,32 @@ def test_transformer_model():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d():
def test_transformer_model_semi():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model()
def test_transformer_model_sp():
"""
Feature: Test Transformer+MoE, with All2All enabled and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0, search_mode="sharding_propagation",
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model()
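
Every new *_sp test in this file follows the same recipe before calling the shared run_* helper: switch parallel_mode to AUTO_PARALLEL, select the "sharding_propagation" search mode, and relax the strategy-search parameters. A minimal sketch of that shared setup, using only APIs already imported in this file; the helper name is illustrative and not part of the change:

from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.parallel import set_algo_parameters

def enable_sharding_propagation(device_num=16):
    """Configure the auto-parallel context so strategies are derived by sharding propagation."""
    set_auto_parallel_context(device_num=device_num, global_rank=0,
                              full_batch=True, enable_alltoall=True,
                              search_mode="sharding_propagation",
                              parallel_mode=ParallelMode.AUTO_PARALLEL)
    # Let propagation decide elementwise strategies and allow strategies that do not
    # span every device.
    set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
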
def run_transformer_model_2d():
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,
@ -153,6 +163,31 @@ def test_transformer_model_2d():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d_semi():
"""
Feature: Test Transformer+MoE, with All2All enabled.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_2d()
def test_transformer_model_2d_sp():
"""
Feature: Test Transformer+MoE, with All2All enabled and sharding propagation.
Description: 2-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, global_rank=0, search_mode="sharding_propagation",
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_2d()
class TransformerNet(nn.Cell):
"""Transformer with loss"""
def __init__(self, en_layer, de_layer, parallel_config):
@ -176,9 +211,6 @@ class TransformerNet(nn.Cell):
def moe_with_loss_plus_mutiparallel(local_parallel_config):
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
@ -205,6 +237,8 @@ def test_moe_expert_parallel1():
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)
@ -215,21 +249,52 @@ def test_moe_expert_parallel2():
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True,
full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel3():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled
and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True, search_mode="sharding_propagation",
full_batch=True, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel4():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled
and sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation with All2All included.
"""
set_auto_parallel_context(device_num=16, enable_alltoall=True, search_mode="sharding_propagation",
full_batch=True, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel_exception1():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: 3-dim input.
Expectation: Successful graph compilation.
Expectation: Raise ValueError.
"""
local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
with pytest.raises(ValueError):
moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel_exception():
def test_moe_expert_parallel_exception2():
"""
Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
Description: data_parallel*model_parallel*expert_parallel > device_num

Changed file 3 of 3

@ -166,10 +166,7 @@ def run_total_transformer_model_head(e_layer,
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def run_transformer_model():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,
@ -197,10 +194,32 @@ def test_transformer_model():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d_inputs():
def test_transformer_model_semi():
"""
Feature: Test Transformer.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model()
def test_transformer_model_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 3-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model()
def run_transformer_model_2d_inputs():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,
@ -228,6 +247,31 @@ def test_transformer_model_2d_inputs():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d_semi():
"""
Feature: Test Transformer.
Description: 2-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_2d_inputs()
def test_transformer_model_2d_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 2-dim input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_2d_inputs()
class TestTransformerEmbeddingHead:
def __init__(self):
self.output_path = None
@ -255,19 +299,12 @@ class TestTransformerEmbeddingHead:
appear_count += 1
assert appear_count == target_count
def test_pipeline_with_embedding(self):
"""
Feature: Test Transformer with a shared embedding.
Description: When pipeline training is combined with optimizer sharding, the model-parallel embedding would
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
def run_pipeline_with_embedding(self):
bs = 16
pp = 2
context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=pp,
full_batch=True,
enable_parallel_optimizer=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
cf = TransformerOpParallelConfig(data_parallel=1, model_parallel=4, pipeline_stage=pp, vocab_emb_dp=False)
pipeline_net = TransformerEncoderNet(batch_size=bs // pp,
en_layer=2, de_layer=0, parallel_config=cf)
@ -284,11 +321,29 @@ class TestTransformerEmbeddingHead:
run_network_function(dataset, pipeline_cell_net)
self.virtual_assign_add_from_ir(pattern=r'AssignAdd(', target_count=35)
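
The assertion above relies on the class's virtual_assign_add_from_ir helper, whose tail (the appear_count check) is visible just before the hunk header. A minimal sketch of one way such a count can be done, assuming IR dumping was enabled with mindspore.context.set_context(save_graphs=True, save_graphs_path=...); the function below is illustrative and not the class's actual helper:

import os
import re

def count_pattern_in_ir(pattern, output_path):
    """Count occurrences of `pattern` across all dumped *.ir files under output_path."""
    appear_count = 0
    for file_name in os.listdir(output_path):
        if not file_name.endswith(".ir"):
            continue
        with open(os.path.join(output_path, file_name), "r") as ir_file:
            appear_count += len(re.findall(re.escape(pattern), ir_file.read()))
    return appear_count

# e.g. assert count_pattern_in_ir("AssignAdd(", "./pipeline_ir") == 35
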
def test_pipeline_with_embedding_semi(self):
"""
Feature: Test Transformer with a shared embedding.
Description: When pipeline training is combined with optimizer sharding, the model-parallel embedding would
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
self.run_pipeline_with_embedding()
def test_transformer_model_int64_inputs():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
def test_pipeline_with_embedding_sp(self):
"""
Feature: Test Transformer with a shared embedding, using sharding propagation.
Description: When pipeline training is combined with optimizer sharding, the model-parallel embedding would
raise a shape error. This test case ensures that no error is raised.
Expectation: The number of AssignAdd operators matches the expected count.
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation")
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
self.run_pipeline_with_embedding()
def run_transformer_model_int64_inputs():
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,
@ -317,16 +372,69 @@ def test_transformer_model_int64_inputs():
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_int64_semi():
"""
Feature: Test Transformer.
Description: int64 input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
run_transformer_model_int64_inputs()
def test_transformer_model_int64_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: int64 input.
Expectation: Successful graph compilation.
"""
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True, search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_transformer_model_int64_inputs()
def test_transformer_model_head_parallel_only_encoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
def test_transformer_model_head_parallel_only_encoder_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: encoder only.
Expectation: Successful graph compilation.
"""
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config,
mode=ParallelMode.AUTO_PARALLEL)
def test_transformer_model_head_parallel():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)
def test_transformer_model_head_parallel_sp():
"""
Feature: Test Transformer with sharding propagation.
Description: 1 encoder layer and 1 decoder layer.
Expectation: Successful graph compilation.
"""
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config,
mode=ParallelMode.AUTO_PARALLEL)
def test_transformer_model_head_parallel_decoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
with pytest.raises(ValueError):
@ -353,8 +461,7 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
"""
set_auto_parallel_context(device_num=64,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
context.set_auto_parallel_context(parallel_optimizer_config=
{"gradient_accumulation_shard": grad_accumulation_shard})
@ -394,6 +501,19 @@ def test_pipeline_transformer_gradient_shard_true():
Description: Test a single transformer model using pipeline parallelism with grad_accumulation_shard True.
Expectation: The compilation passes.
"""
set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_single_transformer(grad_accumulation_shard=True)
def test_pipeline_transformer_gradient_shard_true_sp():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation with sharding propagation
Description: Test a single transformer model using pipeline parallelism with grad_accumulation_shard True.
Expectation: The compilation passes.
"""
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
pipeline_single_transformer(grad_accumulation_shard=True)
@ -403,6 +523,19 @@ def test_pipeline_transformer_gradient_shard_false():
Description: Test a single transformer model using pipeline parallelism with grad_accumulation_shard False.
Expectation: The compilation passes.
"""
set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
pipeline_single_transformer(grad_accumulation_shard=False)
def test_pipeline_transformer_gradient_shard_false_sp():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation with sharding propagation
Description: Test a single transformer model using pipeline parallelism with grad_accumulation_shard False.
Expectation: The compilation passes.
"""
set_auto_parallel_context(search_mode="sharding_propagation",
parallel_mode=ParallelMode.AUTO_PARALLEL)
set_algo_parameters(elementwise_op_strategy_follow=False, fully_use_devices=False)
pipeline_single_transformer(grad_accumulation_shard=False)
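
For reference, the pipeline tests above configure gradient accumulation sharding through the parallel optimizer config before compiling the network. A minimal sketch of that setup, reusing only calls that appear in this file; device_num and the stage count are illustrative values, not the tests' exact configuration:

from mindspore import context
from mindspore.context import set_auto_parallel_context, ParallelMode

def configure_pipeline(grad_accumulation_shard=False, stages=4):
    """Set up semi-auto parallel pipeline training with optional sharded accumulation buffers."""
    set_auto_parallel_context(device_num=64, global_rank=0, full_batch=True,
                              pipeline_stages=stages,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    # Whether gradient accumulation buffers are sharded along the data-parallel dimension.
    context.set_auto_parallel_context(parallel_optimizer_config=
                                      {"gradient_accumulation_shard": grad_accumulation_shard})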