!22905 Fix PanGu Bad Performance

Merge pull request !22905 from huangxinjing/fix_pangu_performance
i-robot 2021-09-06 13:08:41 +00:00 committed by Gitee
commit e63a984416
4 changed files with 62 additions and 6 deletions


@@ -35,7 +35,7 @@
 namespace mindspore {
 namespace parallel {
-#define MAX_DEVICE_NUM 1024
+#define MAX_DEVICE_NUM 4096
 constexpr char HCCL_BACKEND[] = "hccl";
 constexpr char NCCL_BACKEND[] = "nccl";


@@ -359,6 +359,12 @@ class FixedSparseAttention(nn.Cell):
         Validator.check_positive_int(seq_length, "seq_length")
         Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
         dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
+        if num_heads % mp != 0:
+            raise ValueError(f"The number of heads {num_heads} must be a "
+                             f"multiple of parallel_config.model_parallel {mp}.")
+        if batch_size % dp != 0:
+            raise ValueError(f"The batch_size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         self.seq_length = seq_length
         self.batch_size = batch_size
         self.hidden_size = size_per_head * num_heads


@@ -598,6 +598,9 @@ class MultiHeadAttention(Cell):
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
+        if batch_size % parallel_config.data_parallel != 0:
+            raise ValueError(f"The batch size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         # Output layer
         self.projection = _Linear(in_channels=hidden_size,
                                   out_channels=hidden_size,

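The checks added to FixedSparseAttention and MultiHeadAttention above enforce that the sharded dimensions split evenly across the parallel groups. A minimal plain-Python sketch of that constraint (the values and variable names below are illustrative, not taken from this PR):

    # Sketch of the new divisibility checks with example shapes.
    batch_size, num_heads = 32, 40
    data_parallel, model_parallel = 4, 8

    if num_heads % model_parallel != 0:
        raise ValueError(f"num_heads {num_heads} must be a multiple of model_parallel {model_parallel}")
    if batch_size % data_parallel != 0:
        raise ValueError(f"batch_size {batch_size} must be a multiple of data_parallel {data_parallel}")

    print(batch_size // data_parallel, num_heads // model_parallel)  # 8 samples and 5 heads per shard

With these checks, an incompatible configuration is rejected up front with a ValueError instead of surfacing later during parallel compilation.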

@@ -14,6 +14,7 @@
 # ============================================================================
 """PanguAlpha model"""
 import os
+import copy
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
@@ -31,11 +32,15 @@ class EmbeddingLayer(nn.Cell):
     r"""Embedding layer of the PanGUAlpha Model"""
     def __init__(self, config):
         super(EmbeddingLayer, self).__init__()
+        # Only for the pipeline mode, the embedding needs to be row sliced.
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
                                                                     dtype=config.param_init_type),
-                                             parallel_config=config.parallel_config.embedding_dp_mp_config)
+                                             parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
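The deepcopy above keeps the pipeline-only override local to the word embedding, so the shared parallel config object is not mutated for the other layers. A small stand-alone sketch of the same copy-then-override pattern (SimpleNamespace is just a stand-in for the real parallel config class, which is not shown in this diff):

    import copy
    from types import SimpleNamespace

    # Stand-in for config.parallel_config; only illustrates the pattern.
    parallel_config = SimpleNamespace(pipeline_stage=4, vocab_emb_dp=True)

    copied_parallel_config = copy.deepcopy(parallel_config)
    if copied_parallel_config.pipeline_stage > 1:
        # Row-slice the embedding table only in pipeline mode.
        copied_parallel_config.vocab_emb_dp = False

    print(copied_parallel_config.vocab_emb_dp, parallel_config.vocab_emb_dp)  # False True
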
@@ -180,6 +185,39 @@ class PanGuHead(Cell):
         return logits
 
 
+def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
+    r"""
+    Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
+
+    Args:
+        network(Cell) - Represents the transformer block
+        layer_id(int) - Means the layer index for the current module, counts from zero.
+        offset(int) - Means the layer_index needs an offset, if there are other modules in the net.
+        layers(int) - The total layers used for the model.
+    """
+    # Used for the pipeline's stages setting
+    # As the final layer is not included here, we need to manually add 1 here.
+    # original: if we set two stages, the layers on the two stages will be [15, 16+1]
+    # with 1 added, the layers on the two stages will be [16, 15+1]
+    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
+    # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
+    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
+    network.pipeline_stage = pp_id
+    print(f"pipeline stage id is {pp_id}", flush=True)
+
+    # Used for optimizer's fusion tag
+    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
+    if parallel_config.pipeline_stage > 1:
+        # we give the fusion in pipeline mode a fixed value, otherwise the performance may become worse.
+        network.set_comm_fusion(2)
+    else:
+        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
+
+    # Used for enabling recomputation of the block
+    if parallel_config.recompute:
+        network.recompute()
+
+
 class PanguAlpha_Model(Cell):
     r"""The base backbone of the PanGuAlpha model"""
     def __init__(self, config):
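A quick way to check the [15, 16+1] versus [16, 15+1] split described in the comments above is to replay the stage arithmetic outside MindSpore. The sketch below assumes 31 transformer blocks, 2 pipeline stages and no offset (illustrative values, not taken from this PR):

    # Replays the pp_dis / pp_id arithmetic from set_parallel_configure_for_layer.
    layers, pipeline_stage, offset = 31, 2, 0

    def stage_of(layer_id, add_one):
        total = layers + 1 if add_one else layers
        pp_dis = max(int(total / pipeline_stage), 1)
        return min((layer_id + offset) // pp_dis, pipeline_stage - 1)

    old = [stage_of(i, add_one=False) for i in range(layers)]
    new = [stage_of(i, add_one=True) for i in range(layers)]
    print(old.count(0), old.count(1))  # 15 16 -> [15, 16+1] once the head lands on the last stage
    print(new.count(0), new.count(1))  # 16 15 -> [16, 15+1], a more even split
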
@@ -188,11 +226,13 @@ class PanguAlpha_Model(Cell):
         self.embedding = EmbeddingLayer(config)
         self.config = config
         self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32)
-        self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.layernorm.set_comm_fusion(2)
+        else:
+            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
         self.layernorm.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1
         # Configure the shard configure of the Embedding layer
         self.embedding.set_comm_fusion(0)
         self.embedding.pipeline_stage = 0
         self.num_layers = config.num_layers
@@ -205,6 +245,7 @@ class PanguAlpha_Model(Cell):
                                          seq_length=config.seq_length,
                                          attention_dropout_rate=config.dropout_rate,
                                          hidden_dropout_rate=config.dropout_rate,
+                                         lambda_func=set_parallel_configure_for_layer,
                                          param_init_type=config.param_init_type,
                                          use_past=config.use_past,
                                          parallel_config=config.parallel_config).blocks
@@ -216,7 +257,10 @@ class PanguAlpha_Model(Cell):
                                                                        dtype=config.param_init_type),
                                                parallel_config=config.parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
-        self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.top_query_embedding.set_comm_fusion(2)
+        else:
+            self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
 
         self.top_query_layer = QueryLayer(batch_size=config.batch_size,
                                           hidden_size=config.hidden_size,
@@ -312,8 +356,11 @@ class PanguAlphaModel(nn.Cell):
     def __init__(self, config):
         super(PanguAlphaModel, self).__init__()
         # Network head to get logits over vocabulary
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.head = PanGuHead(hidden_size=config.hidden_size,
-                              parallel_config=config.parallel_config)
+                              parallel_config=copied_parallel_config)
         self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1
         self.backbone = PanguAlpha_Model(config)
         self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)