Add transformer

huangxinjing 2021-12-02 20:55:51 +08:00
parent e4fb249c87
commit 8c9b2b93a8
1 changed file with 11 additions and 11 deletions


@@ -80,7 +80,7 @@ class TransformerNet(nn.Cell):
return self.loss(predict, y, mask)
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
-pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
+pipeline_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, pipeline_stage=4,
micro_batch_num=4, vocab_emb_dp=False)
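
Note on the new numbers (not part of the diff): raising data_parallel from 1 to 2 while keeping model_parallel=8 and pipeline_stage=4 doubles the number of devices the strategy occupies, which lines up with device_num moving from 32 to 64 in the hunks below. A minimal sketch of that arithmetic, assuming devices are split as data_parallel * model_parallel per stage across pipeline_stage stages:

def required_devices(data_parallel, model_parallel, pipeline_stage):
    # total devices = cards per pipeline stage * number of stages
    return data_parallel * model_parallel * pipeline_stage

required_devices(1, 8, 4)  # 32, the old device_num
required_devices(2, 8, 4)  # 64, the new device_num
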
@@ -257,14 +257,14 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
Expectation: The compile passed
"""
-set_auto_parallel_context(device_num=32,
+set_auto_parallel_context(device_num=64,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
context.set_auto_parallel_context(parallel_optimizer_config=
{"gradient_accumulation_shard": grad_accumulation_shard})
-net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
+net = Transformer(batch_size=8 // pipeline_config.micro_batch_num,
src_seq_length=20,
tgt_seq_length=10,
encoder_layers=2,
@@ -274,11 +274,11 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
ffn_hidden_size=64,
parallel_config=pipeline_config)
-encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
-encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
-decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
-decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
-memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
+encoder_input_value = Tensor(np.ones((8, 20, 64)), mstype.float32)
+encoder_input_mask = Tensor(np.ones((8, 20, 20)), mstype.float16)
+decoder_input_value = Tensor(np.ones((8, 10, 64)), mstype.float32)
+decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
+memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
net = PipelineCell(net, pipeline_config.micro_batch_num)
net = _VirtualDatasetCell(net)
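
For reference, the enlarged pipeline test now feeds batch-8 tensors into a Transformer built with batch_size=8 // micro_batch_num. A small sketch of that bookkeeping, assuming PipelineCell slices the global batch into micro_batch_num micro-batches:

global_batch = 8        # leading dim of the (8, 20, 64) encoder input above
micro_batch_num = 4     # pipeline_config.micro_batch_num
micro_batch = global_batch // micro_batch_num  # batch_size the Transformer is built with
assert micro_batch == 2
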
@@ -313,7 +313,7 @@ def test_pipeline_transformer_gradient_shard_false():
def test_transformer_wrong_head():
-set_auto_parallel_context(device_num=32,
+set_auto_parallel_context(device_num=64,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
@@ -343,7 +343,7 @@ def test_transformer_wrong_head():
def test_transformer_wrong_dp_no_error():
-set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
+set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
@@ -353,7 +353,7 @@ def test_transformer_wrong_dp_no_error():
def test_transformer_wrong_semi_auto_dp_error():
-set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
+set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
with pytest.raises(ValueError):
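
The last two hunks differ mainly in parallel_mode: under DATA_PARALLEL the mismatched data_parallel=8 strategy is not enforced and no error is expected, while under SEMI_AUTO_PARALLEL the data_parallel=16 config is expected to raise ValueError. A guess at the kind of check being exercised (the divisibility rule here is an assumption, not read from the MindSpore source):

batch_size = 4       # batch used by the wrong-dp tests
data_parallel = 16   # from check_config in the failing test
if batch_size % data_parallel != 0:
    # a batch of 4 cannot be split across 16 data-parallel slices when full_batch=False
    raise ValueError(f"batch_size {batch_size} is not divisible by data_parallel {data_parallel}")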