forked from mindspore-Ecosystem/mindspore
Add transformer
This commit is contained in:
parent
e4fb249c87
commit
8c9b2b93a8
|
@ -80,7 +80,7 @@ class TransformerNet(nn.Cell):
|
|||
return self.loss(predict, y, mask)
|
||||
|
||||
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
|
||||
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
|
||||
pipeline_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, pipeline_stage=4,
|
||||
micro_batch_num=4, vocab_emb_dp=False)
|
||||
|
||||
|
||||
|
@ -257,14 +257,14 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
|
|||
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
|
||||
Expectation: The compile passed
|
||||
"""
|
||||
set_auto_parallel_context(device_num=32,
|
||||
set_auto_parallel_context(device_num=64,
|
||||
full_batch=True,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
context.set_auto_parallel_context(parallel_optimizer_config=
|
||||
{"gradient_accumulation_shard": grad_accumulation_shard})
|
||||
|
||||
net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
|
||||
net = Transformer(batch_size=8 // pipeline_config.micro_batch_num,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
encoder_layers=2,
|
||||
|
@ -274,11 +274,11 @@ def pipeline_single_transformer(grad_accumulation_shard=False):
|
|||
ffn_hidden_size=64,
|
||||
parallel_config=pipeline_config)
|
||||
|
||||
encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
|
||||
encoder_input_value = Tensor(np.ones((8, 20, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((8, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((8, 10, 64)), mstype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
|
||||
net = NetWithLossFiveInputs(net)
|
||||
net = PipelineCell(net, pipeline_config.micro_batch_num)
|
||||
net = _VirtualDatasetCell(net)
|
||||
|
@ -313,7 +313,7 @@ def test_pipeline_transformer_gradient_shard_false():
|
|||
|
||||
|
||||
def test_transformer_wrong_head():
|
||||
set_auto_parallel_context(device_num=32,
|
||||
set_auto_parallel_context(device_num=64,
|
||||
full_batch=True,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
|
@ -343,7 +343,7 @@ def test_transformer_wrong_head():
|
|||
|
||||
|
||||
def test_transformer_wrong_dp_no_error():
|
||||
set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
|
||||
check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
|
||||
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
|
||||
|
@ -353,7 +353,7 @@ def test_transformer_wrong_dp_no_error():
|
|||
|
||||
|
||||
def test_transformer_wrong_semi_auto_dp_error():
|
||||
set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
|
||||
check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
|
||||
with pytest.raises(ValueError):
|
||||
|
|
Loading…
Reference in New Issue