diff --git a/tests/ut/python/parallel/test_parallel_transformer.py b/tests/ut/python/parallel/test_parallel_transformer.py index 6a03b0559c5..4bba1ac1ef4 100644 --- a/tests/ut/python/parallel/test_parallel_transformer.py +++ b/tests/ut/python/parallel/test_parallel_transformer.py @@ -80,7 +80,7 @@ class TransformerNet(nn.Cell): return self.loss(predict, y, mask) config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False) -pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4, +pipeline_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, pipeline_stage=4, micro_batch_num=4, vocab_emb_dp=False) @@ -257,14 +257,14 @@ def pipeline_single_transformer(grad_accumulation_shard=False): Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False Expectation: The compile passed """ - set_auto_parallel_context(device_num=32, + set_auto_parallel_context(device_num=64, full_batch=True, pipeline_stages=pipeline_config.pipeline_stage, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) context.set_auto_parallel_context(parallel_optimizer_config= {"gradient_accumulation_shard": grad_accumulation_shard}) - net = Transformer(batch_size=4 // pipeline_config.micro_batch_num, + net = Transformer(batch_size=8 // pipeline_config.micro_batch_num, src_seq_length=20, tgt_seq_length=10, encoder_layers=2, @@ -274,11 +274,11 @@ def pipeline_single_transformer(grad_accumulation_shard=False): ffn_hidden_size=64, parallel_config=pipeline_config) - encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32) - encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16) - decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32) - decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16) - memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16) + encoder_input_value = Tensor(np.ones((8, 20, 64)), mstype.float32) + encoder_input_mask = Tensor(np.ones((8, 20, 20)), mstype.float16) + decoder_input_value = Tensor(np.ones((8, 10, 64)), mstype.float32) + decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16) + memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16) net = NetWithLossFiveInputs(net) net = PipelineCell(net, pipeline_config.micro_batch_num) net = _VirtualDatasetCell(net) @@ -313,7 +313,7 @@ def test_pipeline_transformer_gradient_shard_false(): def test_transformer_wrong_head(): - set_auto_parallel_context(device_num=32, + set_auto_parallel_context(device_num=64, full_batch=True, pipeline_stages=pipeline_config.pipeline_stage, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) @@ -343,7 +343,7 @@ def test_transformer_wrong_head(): def test_transformer_wrong_dp_no_error(): - set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL, + set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL, pipeline_stages=pipeline_config.pipeline_stage, global_rank=0) check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False) net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2, @@ -353,7 +353,7 @@ def test_transformer_wrong_dp_no_error(): def test_transformer_wrong_semi_auto_dp_error(): - set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, pipeline_stages=pipeline_config.pipeline_stage, global_rank=0) check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False) with pytest.raises(ValueError):