forked from mindspore-Ecosystem/mindspore
!23571 Fix mode check for the transformer model
Merge pull request !23571 from huangxinjing/fix_transformer_mode_check
This commit is contained in:
commit
09dd495070
|
@ -280,7 +280,8 @@ class _Linear(Cell):
|
|||
self.expert_num = expert_num
|
||||
if self.expert_num > 1:
|
||||
self.expert_flag = True
|
||||
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape), name="weight")
|
||||
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
|
||||
name="weight")
|
||||
self.matmul = P.BatchMatMul(transpose_b=transpose_b)
|
||||
else:
|
||||
self.expert_flag = False
|
||||
|
|
|
@ -769,6 +769,7 @@ class MultiHeadAttention(Cell):
|
|||
parallel_config=default_dpmp_config):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
_check_config(parallel_config)
|
||||
self.is_parallel_mode = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.src_seq_length = src_seq_length
|
||||
self.tgt_seq_length = tgt_seq_length
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -784,7 +785,7 @@ class MultiHeadAttention(Cell):
|
|||
if num_heads % parallel_config.model_parallel != 0:
|
||||
raise ValueError(f"The number of heads {num_heads} must be a "
|
||||
f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
|
||||
if batch_size % parallel_config.data_parallel != 0:
|
||||
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
|
||||
raise ValueError(f"The batch size {batch_size} must be a "
|
||||
f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
|
||||
# Output layer
|
||||
|
|
|
@ -281,6 +281,28 @@ def test_transformer_wrong_head():
|
|||
parallel_config=error_test_config)
|
||||
del net
|
||||
|
||||
|
||||
def test_transformer_wrong_dp_no_error():
    """
    Build a Transformer whose batch size is not a multiple of data_parallel
    while the parallel mode is DATA_PARALLEL.

    The configuration uses batch_size=4 with data_parallel=8. No exception is
    caught here, so the test passes only if construction completes without
    raising (as the test name states).
    """
    set_auto_parallel_context(
        device_num=32,
        full_batch=False,
        parallel_mode=ParallelMode.DATA_PARALLEL,
        pipeline_stages=pipeline_config.pipeline_stage,
        global_rank=0,
    )
    dp_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
    model = Transformer(
        batch_size=4,
        src_seq_length=20,
        tgt_seq_length=10,
        encoder_layers=2,
        decoder_layers=2,
        hidden_size=64,
        num_heads=2,
        ffn_hidden_size=64,
        parallel_config=dp_config,
    )
    # Drop the reference explicitly once construction has succeeded.
    del model
|
||||
|
||||
|
||||
def test_transformer_wrong_semi_auto_dp_error():
    """
    Expect a ValueError when the batch size is not a multiple of
    data_parallel and the parallel mode is SEMI_AUTO_PARALLEL.

    The configuration uses batch_size=4 with data_parallel=16; constructing
    the Transformer is wrapped in ``pytest.raises(ValueError)``.
    """
    set_auto_parallel_context(
        device_num=32,
        full_batch=False,
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
        pipeline_stages=pipeline_config.pipeline_stage,
        global_rank=0,
    )
    bad_dp_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
    with pytest.raises(ValueError):
        model = Transformer(
            batch_size=4,
            src_seq_length=20,
            tgt_seq_length=10,
            encoder_layers=2,
            decoder_layers=2,
            hidden_size=64,
            num_heads=2,
            ffn_hidden_size=64,
            parallel_config=bad_dp_config,
        )
        # NOTE(review): kept inside the ``with`` block because the name is only
        # bound if construction unexpectedly succeeds — confirm against the
        # upstream file, whose indentation is not visible here.
        del model
|
||||
|
||||
|
||||
def test_encoder():
|
||||
class NetWithLoss(nn.Cell):
|
||||
def __init__(self, network):
|
||||
|
|
Loading…
Reference in New Issue