forked from mindspore-Ecosystem/mindspore
!22905 Fix PanGu Bad Performance
Merge pull request !22905 from huangxinjing/fix_pangu_performance
Commit: e63a984416
@@ -35,7 +35,7 @@
 namespace mindspore {
 namespace parallel {
-#define MAX_DEVICE_NUM 1024
+#define MAX_DEVICE_NUM 4096

 constexpr char HCCL_BACKEND[] = "hccl";
 constexpr char NCCL_BACKEND[] = "nccl";
@@ -359,6 +359,12 @@ class FixedSparseAttention(nn.Cell):
         Validator.check_positive_int(seq_length, "seq_length")
         Validator.check_positive_int(num_different_global_patterns, "num_different_global_patterns")
         dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
+        if num_heads % mp != 0:
+            raise ValueError(f"The number of heads {num_heads} must be a "
+                             f"multiple of parallel_config.model_parallel {mp}.")
+        if batch_size % dp != 0:
+            raise ValueError(f"The batch_size {batch_size} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         self.seq_length = seq_length
         self.batch_size = batch_size
         self.hidden_size = size_per_head * num_heads
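For reference, a minimal standalone sketch of the constraint these new checks enforce: the head and batch dimensions are split across model-parallel and data-parallel devices, so both must divide evenly. The helper name and the example values below are illustrative, not part of the commit.

def check_shard_divisibility(batch_size, num_heads, dp, mp):
    # dp / mp mirror parallel_config.data_parallel / parallel_config.model_parallel
    if num_heads % mp != 0:
        raise ValueError(f"num_heads {num_heads} must be a multiple of model_parallel {mp}")
    if batch_size % dp != 0:
        raise ValueError(f"batch_size {batch_size} must be a multiple of data_parallel {dp}")

check_shard_divisibility(batch_size=32, num_heads=40, dp=4, mp=8)    # passes: 40 % 8 == 0, 32 % 4 == 0
# check_shard_divisibility(batch_size=32, num_heads=40, dp=4, mp=16) # would raise: 40 % 16 != 0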
@@ -598,6 +598,9 @@ class MultiHeadAttention(Cell):
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
+        if batch_size % parallel_config.data_parallel != 0:
+            raise ValueError(f"The batch size {num_heads} must be a "
+                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
         # Output layer
         self.projection = _Linear(in_channels=hidden_size,
                                   out_channels=hidden_size,
@@ -14,6 +14,7 @@
 # ============================================================================
 """PanguAlpha model"""
 import os
+import copy
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
@@ -31,11 +32,15 @@ class EmbeddingLayer(nn.Cell):
     r"""Embedding layer of the PanGUAlpha Model"""
     def __init__(self, config):
         super(EmbeddingLayer, self).__init__()
+        # Only for the pipeline mode, the embedding needs to be row sliced.
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal", [config.vocab_size, config.hidden_size],
                                                                     dtype=config.param_init_type),
-                                             parallel_config=config.parallel_config.embedding_dp_mp_config)
+                                             parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
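The deep copy above matters because `parallel_config` is shared across the whole model. A small sketch, with a hypothetical stand-in config class, of why flipping `vocab_emb_dp` on the shared object would be wrong:

import copy

class _FakeParallelConfig:          # stand-in for the real parallel config, for illustration only
    def __init__(self, pipeline_stage=1, vocab_emb_dp=True):
        self.pipeline_stage = pipeline_stage
        self.vocab_emb_dp = vocab_emb_dp

shared = _FakeParallelConfig(pipeline_stage=2)
local = copy.deepcopy(shared)       # embedding-local copy, as in the commit
if local.pipeline_stage > 1:
    local.vocab_emb_dp = False      # row-slice only this embedding's lookup table
print(shared.vocab_emb_dp, local.vocab_emb_dp)  # True False -- the shared config is untouched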
@@ -180,6 +185,39 @@ class PanGuHead(Cell):
         return logits


+def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
+    r"""
+    Default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
+
+    Args:
+        network(Cell) - Represents the transformer block
+        layer_id(int) - The layer index for the current module, counting from zero.
+        offset(int) - The offset applied to layer_id when there are other modules in the net.
+        layers(int) - The total number of layers used for the model.
+    """
+    # Used for the pipeline's stage setting.
+    # As the final layer is not included here, we need to add one manually:
+    # originally, with two stages, the layers on the two stages would be [15, 16 + 1];
+    # with 1 added, the layers on the two stages become [16, 15 + 1].
+    pp_dis = max(int((layers + 1) / parallel_config.pipeline_stage), 1)
+    # the pipeline stage must be in [0, parallel_config.pipeline_stage - 1]
+    pp_id = min((layer_id + offset) // pp_dis, parallel_config.pipeline_stage - 1)
+    network.pipeline_stage = pp_id
+    print(f"pipeline stage id is {pp_id}", flush=True)
+
+    # Used for the optimizer's fusion tag
+    dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
+    if parallel_config.pipeline_stage > 1:
+        # In pipeline mode the fusion is given a fixed value, otherwise the performance may become worse.
+        network.set_comm_fusion(2)
+    else:
+        network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
+    # Used for enabling recomputation of the block
+    if parallel_config.recompute:
+        network.recompute()
+
+
 class PanguAlpha_Model(Cell):
     r"""The base backbone of the PanGuAlpha model"""
     def __init__(self, config):
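As a sanity check on the `(layers + 1)` adjustment in `pp_dis`, here is a small standalone calculation; the layer and stage counts are chosen for illustration, not taken from the PanGu config:

# 31 transformer blocks split over 2 pipeline stages, with the final layer
# appended to the last stage afterwards (hence the "+1" in the comments above).
layers, stages, offset = 31, 2, 0

def split(dis):
    ids = [min((i + offset) // dis, stages - 1) for i in range(layers)]
    return [ids.count(s) for s in range(stages)]

old = split(max(int(layers / stages), 1))        # without +1: [15, 16] -> [15, 16 + 1]
new = split(max(int((layers + 1) / stages), 1))  # with +1:    [16, 15] -> [16, 15 + 1]
print(old, new)                                  # [15, 16] [16, 15]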
@@ -188,11 +226,13 @@ class PanguAlpha_Model(Cell):
         self.embedding = EmbeddingLayer(config)
         self.config = config
         self.layernorm = _LayerNorm((config.hidden_size,)).to_float(mstype.float32)
-        self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.layernorm.set_comm_fusion(2)
+        else:
+            self.layernorm.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
         self.layernorm.shard(((config.parallel_config.data_parallel, 1, 1),))
         self.layernorm.pipeline_stage = config.parallel_config.pipeline_stage - 1
         # Configure the shard configure of the Embedding layer
-        self.embedding.set_comm_fusion(0)
         self.embedding.pipeline_stage = 0

         self.num_layers = config.num_layers
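Outside pipeline mode, `set_parallel_configure_for_layer` above falls back to bucketed communication fusion, indexed by `int((layer_id + offset) / dis) + 1`. A short sketch of how those indices come out, using illustrative values:

layers, group, offset = 32, 4, 0                     # illustrative: 32 blocks, 4 fusion groups
dis = max(int(layers / group), 1)                    # 8 layers per fusion bucket
fusion = [int((i + offset) / dis) + 1 for i in range(layers)]
print(sorted(set(fusion)))                           # [1, 2, 3, 4]: four fusion buckets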
@@ -205,6 +245,7 @@ class PanguAlpha_Model(Cell):
                                      seq_length=config.seq_length,
                                      attention_dropout_rate=config.dropout_rate,
                                      hidden_dropout_rate=config.dropout_rate,
+                                     lambda_func=set_parallel_configure_for_layer,
                                      param_init_type=config.param_init_type,
                                      use_past=config.use_past,
                                      parallel_config=config.parallel_config).blocks
@@ -216,7 +257,10 @@ class PanguAlpha_Model(Cell):
                                                                         dtype=config.param_init_type),
                                                  parallel_config=config.parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
-        self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
+        if config.parallel_config.pipeline_stage > 1:
+            self.top_query_embedding.set_comm_fusion(2)
+        else:
+            self.top_query_embedding.set_comm_fusion(config.parallel_config.gradient_aggregation_group)

         self.top_query_layer = QueryLayer(batch_size=config.batch_size,
                                           hidden_size=config.hidden_size,
@@ -312,8 +356,11 @@ class PanguAlphaModel(nn.Cell):
     def __init__(self, config):
         super(PanguAlphaModel, self).__init__()
         # Network head to get logits over vocabulary
+        copied_parallel_config = copy.deepcopy(config.parallel_config)
+        if copied_parallel_config.pipeline_stage > 1:
+            copied_parallel_config.vocab_emb_dp = False
         self.head = PanGuHead(hidden_size=config.hidden_size,
-                              parallel_config=config.parallel_config)
+                              parallel_config=copied_parallel_config)
         self.head.pipeline_stage = config.parallel_config.pipeline_stage - 1
         self.backbone = PanguAlpha_Model(config)
         self.backbone.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)