forked from mindspore-Ecosystem/mindspore
!22905 Fix PanGu Bad Performance
Merge pull request !22905 from huangxinjing/fix_pangu_performance
commit e63a984416
@@ -35,7 +35,7 @@
 namespace mindspore {
 namespace parallel {
-#define MAX_DEVICE_NUM 1024
+#define MAX_DEVICE_NUM 4096

 constexpr char HCCL_BACKEND[] = "hccl";
 constexpr char NCCL_BACKEND[] = "nccl";
@@ -359,6 +359,12 @@ class FixedSparseAttention(nn.Cell):
         Validator.check_positive_int(seq_length, "seq_length")
         Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
         dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
+        if num_heads % mp != 0:
+            raise ValueError(f"The number of heads {num_heads} must be a "
+                             f"multiple of parallel_config.model_parallel {mp}.")
+        if batch_size % dp != 0:
+            raise ValueError(f"The batch_size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         self.seq_length = seq_length
         self.batch_size = batch_size
         self.hidden_size = size_per_head * num_heads
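The two new checks only enforce a divisibility constraint between the attention shape and the parallel degrees. As a minimal stand-alone sketch of that rule (plain Python with made-up numbers and a hypothetical helper name, not the MindSpore classes):

    # Hypothetical helper that mirrors the divisibility rule added in the diff above.
    def check_sharding(batch_size, num_heads, data_parallel, model_parallel):
        if num_heads % model_parallel != 0:
            raise ValueError(f"num_heads {num_heads} must be a multiple of model_parallel {model_parallel}.")
        if batch_size % data_parallel != 0:
            raise ValueError(f"batch_size {batch_size} must be a multiple of data_parallel {data_parallel}.")

    check_sharding(batch_size=32, num_heads=16, data_parallel=8, model_parallel=8)  # ok: 16 % 8 == 0 and 32 % 8 == 0
    check_sharding(batch_size=32, num_heads=12, data_parallel=8, model_parallel=8)  # raises: 12 % 8 != 0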
@@ -598,6 +598,9 @@ class MultiHeadAttention(Cell):
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
+        if batch_size % parallel_config.data_parallel != 0:
+            raise ValueError(f"The batch size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         # Output layer
         self.projection = _Linear(in_channels=hidden_size,
                                   out_channels=hidden_size,
@@ -14,6 +14,7 @@
 # ============================================================================
 """PanguAlpha model"""
 import os
+import copy
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
@@ -31,11 +32,15 @@ class EmbeddingLayer(nn.Cell):
     r"""Embedding layer of the PanGUAlpha Model"""
     def __init__(self, config):
         super(EmbeddingLayer, self).__init__()
+        # Only for the pipeline mode, the embedding needs to be row sliced.
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
                                                                     dtype=config.param_init_type),
-                                             parallel_config=config.parallel_config.embedding_dp_mp_config)
+                                             parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
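The deepcopy above is what keeps the change local: config.parallel_config is shared by the rest of the network, and flipping vocab_emb_dp directly on it would alter the sharding of every layer that reads it. A small illustration of the pattern, using a stub dataclass rather than the real MindSpore parallel config:

    import copy
    from dataclasses import dataclass

    @dataclass
    class StubParallelConfig:        # stand-in for the real parallel config object
        pipeline_stage: int = 2
        vocab_emb_dp: bool = True    # True: data-parallel vocab embedding; False: row-sliced

    shared = StubParallelConfig()
    local = copy.deepcopy(shared)
    if local.pipeline_stage > 1:     # only in pipeline mode is the embedding row-sliced
        local.vocab_emb_dp = False

    assert shared.vocab_emb_dp is True    # the shared config is untouched
    assert local.vocab_emb_dp is False    # only the embedding's private copy changed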
@@ -180,6 +185,39 @@ class PanGuHead(Cell):
         return logits


+def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
+    r"""
+    Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
+
+
+    Args:
+        network(Cell) - Represents the transformer block.
+        layer_id(int) - The layer index of the current module, counting from zero.
+        offset(int) - An offset applied to the layer_id if there are other modules in the net.
+        layers(int) - The total number of layers used in the model.
+    """
+    # Used for setting the pipeline's stages.
+    # As the final layer is not counted in `layers`, 1 is added manually.
+    # Original: with two stages, the layers per stage would be [15, 16+1];
+    # with the 1 added, the layers per stage become [16, 15+1].
+    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
+    # The pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
+    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
+    network.pipeline_stage = pp_id
+    print(f"pipeline stage id is {pp_id}", flush=True)
+
+    # Used for the optimizer's fusion tag.
+    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
+    if parallel_config.pipeline_stage > 1:
+        # The fusion in pipeline mode is given a fixed value, otherwise the performance may become worse.
+        network.set_comm_fusion(2)
+    else:
+        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
+    # Used for enabling recomputation of the block.
+    if parallel_config.recompute:
+        network.recompute()
+
+
 class PanguAlpha_Model(Cell):
     r"""The base backbone of the PanGuAlpha model"""
     def __init__(self, config):
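To make the `+ 1` in pp_dis concrete, here is a worked example of the stage-assignment rule with assumed values (31 transformer layers, two pipeline stages, offset 0, chosen to match the [15, 16+1] / [16, 15+1] comment above):

    layers, pipeline_stage, offset = 31, 2, 0     # illustrative values only

    pp_dis = max(int((layers + 1) / pipeline_stage), 1)        # 16
    stages = [min((layer_id + offset) // pp_dis, pipeline_stage - 1)
              for layer_id in range(layers)]
    print(stages.count(0), stages.count(1))                    # 16 15

    # The final head (the "+1" layer) also lands on the last stage, giving [16, 15+1].
    # Without the +1, pp_dis would be 15 and the split would be the more unbalanced [15, 16+1].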
@@ -188,11 +226,13 @@ class PanguAlpha_Model(Cell):
         self.embedding = EmbeddingLayer(config)
         self.config = config
         self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32)
-        self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.layernorm.set_comm_fusion(2)
+        else:
+            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
         self.layernorm.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1
         # Configure the shard configure of the Embedding layer
         self.embedding.set_comm_fusion(0)
         self.embedding.pipeline_stage = 0

         self.num_layers = config.num_layers
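The else branch here (and in set_parallel_configure_for_layer above) buckets parameters into parallel_config.gradient_aggregation_group communication-fusion groups, so gradients sharing a fusion index are fused into one all-reduce; in pipeline mode the index is instead pinned to the fixed value 2. A quick sketch of the non-pipeline bucketing, with assumed numbers:

    layers, gradient_aggregation_group = 32, 4    # illustrative values only
    dis = max(int(layers / gradient_aggregation_group), 1)         # 8
    fusion_ids = [int(layer_id / dis) + 1 for layer_id in range(layers)]
    print(sorted(set(fusion_ids)))                                 # [1, 2, 3, 4]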
@@ -205,6 +245,7 @@ class PanguAlpha_Model(Cell):
                                            seq_length=config.seq_length,
                                            attention_dropout_rate=config.dropout_rate,
                                            hidden_dropout_rate=config.dropout_rate,
+                                           lambda_func=set_parallel_configure_for_layer,
                                            param_init_type=config.param_init_type,
                                            use_past=config.use_past,
                                            parallel_config=config.parallel_config).blocks
@@ -216,7 +257,10 @@ class PanguAlpha_Model(Cell):
                                                                         dtype=config.param_init_type),
                                                  parallel_config=config.parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
-        self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.top_query_embedding.set_comm_fusion(2)
+        else:
+            self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

         self.top_query_layer = QueryLayer(batch_size=config.batch_size,
                                           hidden_size=config.hidden_size,
@@ -312,8 +356,11 @@ class PanguAlphaModel(nn.Cell):
     def __init__(self, config):
         super(PanguAlphaModel, self).__init__()
         # Network head to get logits over vocabulary
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.head = PanGuHead(hidden_size=config.hidden_size,
-                              parallel_config=config.parallel_config)
+                              parallel_config=copied_parallel_config)
         self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1
         self.backbone = PanguAlpha_Model(config)
         self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)