!22905 Fix PanGu Bad Performance

Merge pull request !22905 from huangxinjing/fix_pangu_performance
i-robot 2021-09-06 13:08:41 +00:00 committed by Gitee
commit e63a984416
4 changed files with 62 additions and 6 deletions

View File

@@ -35,7 +35,7 @@
 namespace mindspore {
 namespace parallel {
-#define MAX_DEVICE_NUM 1024
+#define MAX_DEVICE_NUM 4096
 constexpr char HCCL_BACKEND[] = "hccl";
 constexpr char NCCL_BACKEND[] = "nccl";

View File

@@ -359,6 +359,12 @@ class FixedSparseAttention(nn.Cell):
         Validator.check_positive_int(seq_length, "seq_length")
         Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
         dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
+        if num_heads % mp != 0:
+            raise ValueError(f"The number of heads {num_heads} must be a "
+                             f"multiple of parallel_config.model_parallel {mp}.")
+        if batch_size % dp != 0:
+            raise ValueError(f"The batch_size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         self.seq_length = seq_length
         self.batch_size = batch_size
         self.hidden_size = size_per_head * num_heads

View File

@@ -598,6 +598,9 @@ class MultiHeadAttention(Cell):
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
+        if batch_size % parallel_config.data_parallel != 0:
+            raise ValueError(f"The batch size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         # Output layer
         self.projection = _Linear(in_channels=hidden_size,
                                   out_channels=hidden_size,
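
Both new checks enforce the same contract: the batch is split along the data-parallel dimension and the attention heads along the model-parallel dimension, so each must divide evenly. A minimal standalone sketch of that contract (plain Python, independent of MindSpore; the helper name check_divisible is made up for illustration):

def check_divisible(value, parallel, value_name, parallel_name):
    """Hypothetical helper mirroring the validation added above."""
    if value % parallel != 0:
        raise ValueError(f"The {value_name} {value} must be a "
                         f"multiple of {parallel_name} {parallel}.")

check_divisible(32, 8, "number of heads", "parallel_config.model_parallel")  # 32 % 8 == 0, passes
try:
    check_divisible(12, 8, "batch_size", "parallel_config.data_parallel")    # 12 % 8 != 0
except ValueError as err:
    print(err)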

View File

@@ -14,6 +14,7 @@
 # ============================================================================
 """PanguAlpha model"""
 import os
+import copy
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
@@ -31,11 +32,15 @@ class EmbeddingLayer(nn.Cell):
     r"""Embedding layer of the PanGUAlpha Model"""
     def __init__(self, config):
         super(EmbeddingLayer, self).__init__()
+        # Only for the pipeline mode, the embedding needs to be row sliced.
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
                                                                     dtype=config.param_init_type),
-                                             parallel_config=config.parallel_config.embedding_dp_mp_config)
+                                             parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
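
The deep copy above matters because config.parallel_config is shared by the rest of the network: flipping vocab_emb_dp in place would also switch every other embedding to row-sliced (vocabulary-sharded) mode. A minimal sketch of the copy-then-override pattern, using a hypothetical ParallelConfig dataclass as a stand-in for the real MindSpore configuration object:

import copy
from dataclasses import dataclass

@dataclass
class ParallelConfig:
    """Hypothetical stand-in holding only the two fields touched here."""
    pipeline_stage: int = 2
    vocab_emb_dp: bool = True

shared = ParallelConfig()

# Copy first, then override, so only this layer sees vocab_emb_dp=False
# (i.e. its embedding table is sharded along the vocabulary/row dimension).
local = copy.deepcopy(shared)
if local.pipeline_stage > 1:
    local.vocab_emb_dp = False

assert shared.vocab_emb_dp is True and local.vocab_emb_dp is False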
@@ -180,6 +185,39 @@ class PanGuHead(Cell):
         return logits

+def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
+    r"""
+    Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
+
+    Args:
+        network(Cell) - Represents the transformer block.
+        layer_id(int) - The layer index of the current module, counting from zero.
+        offset(int) - The offset added to layer_id when other modules precede the blocks in the net.
+        parallel_config - The parallel configuration providing the pipeline_stage,
+            gradient_aggregation_group and recompute settings.
+        layers(int) - The total number of layers used for the model.
+    """
+    # Used for setting the layer's pipeline stage.
+    # The final (top query) layer is not counted in `layers`, so 1 is added manually:
+    # e.g. with 31 blocks on two stages, the split would otherwise be [15, 16+1];
+    # with the 1 added it becomes [16, 15+1].
+    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
+    # The pipeline stage id must be in [0, parallel_config.pipeline_stage - 1].
+    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
+    network.pipeline_stage = pp_id
+    print(f"pipeline stage id is {pp_id}", flush=True)
+
+    # Used for the optimizer's gradient-fusion tag.
+    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
+    if parallel_config.pipeline_stage > 1:
+        # In pipeline mode the fusion tag is fixed; otherwise the performance may become worse.
+        network.set_comm_fusion(2)
+    else:
+        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
+    # Enable recomputation of the block if configured.
+    if parallel_config.recompute:
+        network.recompute()

 class PanguAlpha_Model(Cell):
     r"""The base backbone of the PanGuAlpha model"""
     def __init__(self, config):
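
As a sanity check on the mapping above, a minimal standalone sketch (plain Python, with a hypothetical _Config object standing in for parallel_config) that reproduces the [16, 15+1] split described in the comments for 31 transformer blocks on 2 pipeline stages:

class _Config:
    """Hypothetical stand-in exposing only pipeline_stage."""
    pipeline_stage = 2

layers, offset = 31, 0   # 31 transformer blocks; the top query layer is pinned to the last stage separately
pp_dis = max(int((layers + 1) / _Config.pipeline_stage), 1)           # (31 + 1) / 2 = 16
stage_of = [min((layer_id + offset) // pp_dis, _Config.pipeline_stage - 1)
            for layer_id in range(layers)]

# Blocks 0..15 land on stage 0 and blocks 16..30 on stage 1; adding the query
# layer on the last stage gives the [16, 15 + 1] split from the comment.
print([stage_of.count(s) for s in range(_Config.pipeline_stage)])     # [16, 15]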
@@ -188,11 +226,13 @@ class PanguAlpha_Model(Cell):
         self.embedding = EmbeddingLayer(config)
         self.config = config
         self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32)
-        self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.layernorm.set_comm_fusion(2)
+        else:
+            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
         self.layernorm.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1
         # Configure the shard configure of the Embedding layer
-        self.embedding.set_comm_fusion(0)
         self.embedding.pipeline_stage = 0
         self.num_layers = config.num_layers
@@ -205,6 +245,7 @@
                                          seq_length=config.seq_length,
                                          attention_dropout_rate=config.dropout_rate,
                                          hidden_dropout_rate=config.dropout_rate,
+                                         lambda_func=set_parallel_configure_for_layer,
                                          param_init_type=config.param_init_type,
                                          use_past=config.use_past,
                                          parallel_config=config.parallel_config).blocks
@@ -216,7 +257,10 @@
                                                                            dtype=config.param_init_type),
                                                   parallel_config=config.parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
-        self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.top_query_embedding.set_comm_fusion(2)
+        else:
+            self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

         self.top_query_layer = QueryLayer(batch_size=config.batch_size,
                                           hidden_size=config.hidden_size,
@@ -312,8 +356,11 @@ class PanguAlphaModel(nn.Cell):
     def __init__(self, config):
         super(PanguAlphaModel, self).__init__()
         # Network head to get logits over vocabulary
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.head = PanGuHead(hidden_size=config.hidden_size,
-                              parallel_config=config.parallel_config)
+                              parallel_config=copied_parallel_config)
         self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1
         self.backbone = PanguAlpha_Model(config)
         self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)