fix batch size error

This commit is contained in:
huangxinjing 2021-09-16 09:12:25 +08:00
parent 57b8c8e2d9
commit 0b89d5c9c4
3 changed files with 26 additions and 2 deletions

View File

@ -280,7 +280,8 @@ class _Linear(Cell):
self.expert_num = expert_num
if self.expert_num > 1:
self.expert_flag = True
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape), name="weight")
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape, param_init_type),
name="weight")
self.matmul = P.BatchMatMul(transpose_b=transpose_b)
else:
self.expert_flag = False

View File

@ -769,6 +769,7 @@ class MultiHeadAttention(Cell):
parallel_config=default_dpmp_config):
super(MultiHeadAttention, self).__init__()
_check_config(parallel_config)
self.is_parallel_mode = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.src_seq_length = src_seq_length
self.tgt_seq_length = tgt_seq_length
self.hidden_size = hidden_size
@ -784,7 +785,7 @@ class MultiHeadAttention(Cell):
if num_heads % parallel_config.model_parallel != 0:
raise ValueError(f"The number of heads {num_heads} must be a "
f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
if batch_size % parallel_config.data_parallel != 0:
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
raise ValueError(f"The batch size {batch_size} must be a "
f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
# Output layer

View File

@ -281,6 +281,28 @@ def test_transformer_wrong_head():
parallel_config=error_test_config)
del net
def test_transformer_wrong_dp_no_error():
set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
parallel_config=check_config)
del net
def test_transformer_wrong_semi_auto_dp_error():
set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
with pytest.raises(ValueError):
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
parallel_config=check_config)
del net
def test_encoder():
class NetWithLoss(nn.Cell):
def __init__(self, network):