diff --git a/mindspore/parallel/nn/layers.py b/mindspore/parallel/nn/layers.py
index dc4a4b4787d..f0070a61b5d 100644
--- a/mindspore/parallel/nn/layers.py
+++ b/mindspore/parallel/nn/layers.py
@@ -280,7 +280,8 @@ class _Linear(Cell):
         self.expert_num = expert_num
         if self.expert_num > 1:
             self.expert_flag = True
-            self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape), name="weight")
+            self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
+                                    name="weight")
             self.matmul = P.BatchMatMul(transpose_b=transpose_b)
         else:
             self.expert_flag = False
diff --git a/mindspore/parallel/nn/transformer.py b/mindspore/parallel/nn/transformer.py
index 1c80f072f6c..5cb7613aac2 100644
--- a/mindspore/parallel/nn/transformer.py
+++ b/mindspore/parallel/nn/transformer.py
@@ -769,6 +769,7 @@ class MultiHeadAttention(Cell):
                  parallel_config=default_dpmp_config):
         super(MultiHeadAttention, self).__init__()
         _check_config(parallel_config)
+        self.is_parallel_mode = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.src_seq_length = src_seq_length
         self.tgt_seq_length = tgt_seq_length
         self.hidden_size = hidden_size
@@ -784,7 +785,7 @@
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
-        if batch_size % parallel_config.data_parallel != 0:
+        if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
             raise ValueError(f"The batch size {batch_size} must be a "
                              f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         # Output layer
diff --git a/tests/ut/python/parallel/test_parallel_transformer.py b/tests/ut/python/parallel/test_parallel_transformer.py
index 5be027aa0b9..2d6fd79029a 100644
--- a/tests/ut/python/parallel/test_parallel_transformer.py
+++ b/tests/ut/python/parallel/test_parallel_transformer.py
@@ -281,6 +281,28 @@ def test_transformer_wrong_head():
                           parallel_config=error_test_config)
     del net
 
+
+def test_transformer_wrong_dp_no_error():
+    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
+                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
+    check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
+    net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
+                      decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
+                      parallel_config=check_config)
+    del net
+
+
+def test_transformer_wrong_semi_auto_dp_error():
+    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
+                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
+    check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
+                          decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
+                          parallel_config=check_config)
+    del net
+
+
 def test_encoder():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):