Modify PanguAlpha model

yao_yf 2021-07-02 16:57:11 +08:00
parent fb6ec96862
commit 327ba1962a
9 changed files with 220 additions and 443 deletions

View File

@ -1330,10 +1330,10 @@ class Cell(Cell_):
Raises:
RuntimeError: If there is a parameter that does not belong to any stage.
"""
from mindspore.communication import get_group_size, get_rank
from mindspore.parallel._utils import _get_global_rank, _get_device_num
stage_num = context.get_auto_parallel_context("pipeline_stages")
device_num = get_group_size()
rank_id = get_rank()
device_num = _get_device_num()
rank_id = _get_global_rank()
per_stage_devices = device_num // stage_num
current_stage = rank_id // per_stage_devices
params = []
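For reference, a minimal pure-Python sketch of the stage arithmetic used here (illustrative values assumed; not part of the diff):

# Sketch: mapping a global rank id to its pipeline stage (values are assumptions).
device_num = 8                                 # e.g. result of _get_device_num()
stage_num = 4                                  # context "pipeline_stages"
rank_id = 5                                    # e.g. result of _get_global_rank()
per_stage_devices = device_num // stage_num    # 2 devices per stage
current_stage = rank_id // per_stage_devices   # rank 5 belongs to stage 2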

View File

@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common.dtype import pytype_to_dtype
from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
from ..ops import operations as P
@ -158,7 +158,7 @@ def connect_network_with_dataset(network, dataset_helper):
network = dataset_iter.__network_manage__[key]
else:
if _need_to_full():
device_num = _get_device_num()
device_num = _get_device_num() // _get_pipeline_stages()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
@ -358,7 +358,7 @@ class _DatasetIterGE(_DatasetIter):
self.sink_count = self.get_sink_count(dataset)
batch_expand_num = 1
if _need_to_full():
batch_expand_num = _get_device_num()
batch_expand_num = _get_device_num() // _get_pipeline_stages()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
@ -393,7 +393,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):
# Use a complete tensor to compile, and slice the tensor to run. The batch dimension of the tensors used for
# compilation is device_number times the batch dimension of the tensors used for running. Currently only LoopSink is supported.
if _need_to_full():
device_num = _get_device_num()
device_num = _get_device_num() // _get_pipeline_stages()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
def op():
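For reference, a hedged sketch of the effect of dividing the expansion factor by the pipeline stage count; to_full_shapes below is an illustrative stand-in for the MindSpore helper, and all values are assumptions:

# Sketch only: expand the batch dimension across the devices of one pipeline stage.
def to_full_shapes(shapes, expand_num):
    return [(s[0] * expand_num,) + tuple(s[1:]) for s in shapes]

device_num = 32                                  # assumed total device count
pipeline_stages = 4                              # assumed stage count
dataset_shapes = [(8, 1024), (8, 1024)]          # assumed per-device dataset shapes
full_shapes = to_full_shapes(dataset_shapes, device_num // pipeline_stages)
# full_shapes == [(64, 1024), (64, 1024)]: expanded over the 8 devices of a single stage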

View File

@ -98,7 +98,6 @@ def load_model(args_opt):
dropout_rate=0.0,
compute_dtype=mstype.float16,
use_past=use_past,
self_layernorm=True,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
eod_reset=False,

View File

@ -205,7 +205,6 @@ def generate_increment(model, origin_inputs, config):
log_probs = logits.reshape(1, config.vocab_size)
# Get the revised log_probs considering frequency and presence penalty to eliminate duplicates in the generated results
log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative)

View File

@ -220,51 +220,9 @@ class Output(nn.Cell):
output = self.dropout(output)
return output
class AttentionMask(nn.Cell):
r"""
Get the attention matrix for self-attention module
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_mask: the mask indicating whether each position is a valid input
Returns:
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
"""
def __init__(self, config):
super(AttentionMask, self).__init__()
self.reshape = P.Reshape()
self.mul = P.BatchMatMul().shard(
((config.dp, 1, 1), (config.dp, 1, 1))) # yzz: use 64, 1, 1?
self.expand_dim = P.ExpandDims().shard(((1, 1),))
ones = np.ones(shape=(config.seq_length, config.seq_length))
# Default lower triangle mask matrix
self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))
def construct(self, input_mask):
r"""
Generate the attention mask matrix.
"""
input_shape = P.Shape()(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
# Mask the padded inputs
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.mul(mask_left, mask_right)
lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
# [bs, seq_length, seq_length]
attention_mask = self.multiply(
attention_mask, lower_triangle)
return attention_mask
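For reference, a numpy sketch of what this (now removed) AttentionMask cell computes, with assumed shapes:

# Sketch only: padding mask outer product multiplied by a lower-triangular (causal) mask.
import numpy as np

batch_size, seq_length = 1, 4
input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)       # last position is padding
mask_left = input_mask.reshape(batch_size, seq_length, 1)
mask_right = input_mask.reshape(batch_size, 1, seq_length)
attention_mask = mask_left @ mask_right                       # zero out padded rows/columns
lower_triangle = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))[None, :, :]
attention_mask = attention_mask * lower_triangle              # shape (batch_size, seq_length, seq_length)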
class EmbeddingLookup(nn.Cell):
"""
The embedding lookup table for vocabulary
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
Returns:
@ -272,40 +230,11 @@ class EmbeddingLookup(nn.Cell):
seq_length, embedding_size)
self.embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
def __init__(self):
super(EmbeddingLookup, self).__init__()
self.vocab_size = config.vocab_size
self.embedding_size = config.embedding_size
self.gather = P.GatherV2()
if config.word_emb_dp:
self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
else:
self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
self.shape = (-1, config.seq_length, config.embedding_size)
if config.stage_num > 1:
self.construct = self.construct_pipeline
self.gather.add_prim_attr("parameter_start", 0)
else:
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
if os.path.exists(embedding_path):
e_table = np.load(embedding_path)
e_table = Tensor(e_table, mstype.float32)
self.embedding_table = Parameter(e_table, name="embedding_table")
else:
raise ValueError(f"{embedding_path} file not exits, "
f"please check whether word_embedding file exist.")
else:
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
name="embedding_table")
self.construct = self.construct_no_pipeline
def construct_no_pipeline(self, input_ids):
output = self.gather(self.embedding_table, input_ids, 0)
return output, self.embedding_table
def construct_pipeline(self, input_ids, table):
def construct(self, input_ids, table):
output = self.gather(table, input_ids, 0)
return output
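For reference, a numpy sketch of the refactored lookup: the table is now owned by the outer network (so it can be shared with the head and across pipeline stages) and the cell only gathers rows from it. Shapes and values are assumptions:

# Sketch only: gather rows of an externally owned embedding table (stand-in for P.GatherV2).
import numpy as np

vocab_size, embedding_size = 10, 4
table = np.random.randn(vocab_size, embedding_size).astype(np.float32)  # owned by the caller
input_ids = np.array([[1, 5, 5, 9]])                                    # (batch, seq_length)
output = np.take(table, input_ids, axis=0)                              # (batch, seq_length, embedding_size)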
@ -320,8 +249,6 @@ class Attention(nn.Cell):
"""
def __init__(self, config, scale=1.0, layer_idx=None):
super(Attention, self).__init__()
# Attention mask matrix
self.get_attention_mask = AttentionMask(config)
# Output layer
self.projection = Mapping(config, config.embedding_size,
config.embedding_size, scale)
@ -625,19 +552,8 @@ class Decoder(nn.Cell):
def __init__(self, config, layer_idx):
super(Decoder, self).__init__()
scale = 1 / math.sqrt(2.0 * config.num_layers)
if config.self_layernorm:
self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
else:
self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32)
self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32)
self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
self.layernorm1.gamma.parallel_optimizer = False
self.layernorm1.beta.parallel_optimizer = False
self.layernorm2.gamma.parallel_optimizer = False
self.layernorm2.beta.parallel_optimizer = False
self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
self.attention = Attention(config, scale, layer_idx)
# Feed Forward Network, FFN
@ -715,16 +631,40 @@ class Decoder(nn.Cell):
output = self.add(x, mlp_logit)
return output, layer_present
class PanguAlpha_EmbeddingPipeLine(nn.Cell):
class EmbeddingCell(nn.Cell):
"""
PanguAlpha_EmbeddingPipeLine
EmbeddingCell
"""
def __init__(self, config):
super(PanguAlpha_EmbeddingPipeLine, self).__init__()
self.word_embedding = EmbeddingLookup(config)
self.position_embedding = nn.Embedding(config.seq_length,
config.embedding_size,
embedding_table=Normal(0.02))
super(EmbeddingCell, self).__init__()
self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
if config.word_emb_dp:
self.word_embedding.gather.shard(((1, 1), (config.dp, 1)))
else:
self.word_embedding.gather.shard(((config.mp, 1), (1, 1)))
if config.stage_num > 1:
self.position_embedding = nn.Embedding(config.seq_length,
config.embedding_size,
embedding_table=Normal(0.02)).set_comm_fusion(1)
else:
# Position embedding
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy')
if os.path.exists(embedding_path):
p_table = np.load(embedding_path)
position_table_param = Tensor(p_table, mstype.float32)
else:
raise ValueError(f"{embedding_path} file not exits, "
f"please check whether position_embedding file exit.")
else:
position_table_param = TruncatedNormal(0.02)
# Position embedding
self.position_embedding = nn.Embedding(
config.seq_length,
config.embedding_size,
embedding_table=position_table_param).set_comm_fusion(1)
self.position_embedding.embedding_table.parallel_optimizer = False
self.position_embedding.gather.shard(((1, 1), (config.dp,)))
self.position_embedding.expand.shard(((config.dp, 1),))
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
@ -746,16 +686,15 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell):
return hidden_states
class PanguAlpha_Mask(nn.Cell):
class MaskCell(nn.Cell):
"""
PanguAlpha_Mask
MaskCell
"""
def __init__(self, config):
super(PanguAlpha_Mask, self).__init__()
self.get_attention_mask = AttentionMask(config)
super(MaskCell, self).__init__()
self.dtype = config.compute_dtype
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
def construct(self, input_mask, attention_mask):
def construct(self, attention_mask):
attention_mask = self.expand_dims(attention_mask, 1)
return attention_mask
@ -914,7 +853,7 @@ class QueryLayer(nn.Cell):
output = self.add(x, mlp_logit)
return output, layer_present
class Embedding(nn.Cell):
class PanguAlphaEmbedding(nn.Cell):
"""
Input embedding, i.e., word embedding and position embedding
Args:
@ -931,174 +870,21 @@ class Embedding(nn.Cell):
embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
"""
def __init__(self, config):
super(Embedding, self).__init__()
self.get_attention_mask = AttentionMask(config)
# Word embedding
self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy')
if os.path.exists(embedding_path):
p_table = np.load(embedding_path)
position_table_param = Tensor(p_table, mstype.float32)
else:
raise ValueError(f"{embedding_path} file not exits, please check whether position_embedding file exit.")
else:
position_table_param = TruncatedNormal(0.02)
super(PanguAlphaEmbedding, self).__init__()
self.embedding = EmbeddingCell(config)
if config.stage_num > 1:
self.embedding.pipeline_stage = 0
self.mask = MaskCell(config)
# Position embedding
self.position_embedding = nn.Embedding(
config.seq_length,
config.embedding_size,
embedding_table=position_table_param).set_comm_fusion(1)
self.word_embedding.embedding_table.parallel_optimizer = False
self.position_embedding.embedding_table.parallel_optimizer = False
self.position_embedding.gather.shard(((1, 1), (config.dp,)))
self.position_embedding.expand.shard(((config.dp, 1),))
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.eod_reset = config.eod_reset
self.use_past = config.use_past
self.is_first_iteration = True
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, valid_index=None):
def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None):
"""
Calculate input embeddings via input token ids and input position
"""
# Word embedding
input_embedding, embedding_table = self.word_embedding(input_ids)
# If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids,
# and the corresponding input_position and attention_mask will be derived from it.
if not self.eod_reset:
batch_size, seq_length = F.shape(input_ids)
attention_mask = self.get_attention_mask(input_mask)
if self.use_past and not self.is_first_iteration:
input_position = valid_index.view(1, seq_length)
else:
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (batch_size, 1))
position_embedding = self.position_embedding(input_position)
# Input features [bs, seq_length, embedding_size]
hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states)
hidden_states = P.Cast()(hidden_states, mstype.float16)
attention_mask = self.expand_dims(attention_mask, 1)
return hidden_states, attention_mask, embedding_table
hidden_states = self.embedding(input_ids, table, input_position, valid_index)
attention_mask = self.mask(attention_mask)
return hidden_states, attention_mask
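For reference, a numpy sketch of the embedding path when eod_reset is disabled and position ids are derived from a tiled range (sharding, dropout, and casts omitted; values assumed):

# Sketch only: word embedding plus derived position embedding.
import numpy as np

batch_size, seq_length, embedding_size, vocab_size = 2, 6, 8, 100
word_table = np.random.randn(vocab_size, embedding_size).astype(np.float32)
position_table = np.random.randn(seq_length, embedding_size).astype(np.float32)
input_ids = np.random.randint(0, vocab_size, (batch_size, seq_length))

input_position = np.tile(np.arange(seq_length), (batch_size, 1))     # derived when eod_reset is off
hidden_states = (np.take(word_table, input_ids, axis=0)
                 + np.take(position_table, input_position, axis=0))  # (2, 6, 8)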
class PanguAlpha_Model(nn.Cell):
"""
The backbone of PanguAlpha network
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
input_mask: the mask indicating whether each position is a valid input
input_position: the position index of each token
attention_mask: the attention_mask attention for self-attention module
init_reset: whether reset saved key and value states
batch_valid_length: the valid input sequence length without padding
Returns:
output_state: Tensor, the output logit of backbone
embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
super(PanguAlpha_Model, self).__init__()
self.embedding = Embedding(config)
self.blocks = nn.CellList()
# Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
fusion_group_num = 4
fusion_group_size = config.num_layers // fusion_group_num
fusion_group_size = max(fusion_group_size, 1)
num_layers = config.num_layers
# If top_query_attention is enabled, replace the last normal self-attention layer with this top_query_attention layer
if config.use_top_query_attention:
num_layers -= 1
self.num_layers = num_layers
print("After setting the layer is:", num_layers, flush=True)
for i in range(num_layers):
per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
# Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
# in other words, in the backward process each block will be almost totally recomputed.
if config.use_recompute:
per_block.recompute()
self.blocks.append(per_block)
if config.self_layernorm:
self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
mstype.float32).set_comm_fusion(
int((num_layers - 1) / fusion_group_size) + 2)
else:
self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(
mstype.float32).set_comm_fusion(
int((num_layers - 1) / fusion_group_size) + 2)
self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
self.layernorm.gamma.parallel_optimizer = False
self.layernorm.beta.parallel_optimizer = False
self.use_past = config.use_past
self.past = tuple([None] * config.num_layers)
self.dtype = config.compute_dtype
# If top_query_attention is enabled, the input_position representing the position ids will be used as the index
# into a query embedding table to obtain the top query hidden states; together with the previous outputs of the normal
# self-attention layers, a new attention layer will be attached to the output of the model.
if config.use_top_query_attention:
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy')
if os.path.exists(embedding_path):
top_query_table = np.load(embedding_path)
top_query_table_param = Tensor(top_query_table, mstype.float32)
else:
raise ValueError(
f"{embedding_path} file does not exist, please check whether the top_query_embedding file exists.")
else:
top_query_table_param = TruncatedNormal(0.02)
self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
embedding_table=top_query_table_param)
# If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
# the bias parameters in linear layers (fp16 gradient), causing a dtype error for communication operators,
# so we fuse the communication of the embedding into a large fusion value (+100).
self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 200)
self.top_query_embedding.embedding_table.parallel_optimizer = False
self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
self.top_query_embedding.expand.shard(((config.dp, 1),))
self.top_query_layer = QueryLayer(config)
if config.use_recompute:
self.top_query_layer.recompute()
self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
self.use_top_query_attention = config.use_top_query_attention
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
init_reset=True, batch_valid_length=None):
"""PanguAlpha model"""
# embedding for input_ids and the lower triangle like attention_mask matrix
hidden_states, attention_mask, embedding_table = self.embedding(input_ids, input_mask,
input_position, attention_mask,
batch_valid_length)
# Loop through each self-attention layer
for i in range(self.num_layers):
hidden_states, _ = self.blocks[i](hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
# Top query attention layer
if self.use_top_query_attention:
top_query_hidden_states = self.top_query_embedding(input_position)
output_state, _ = self.top_query_layer(output_state, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
return output_state, embedding_table
class PanguAlpha_ModelPipeline(nn.Cell):
"""
The backbone of PanguAlpha network
Args:
@ -1113,55 +899,92 @@ class PanguAlpha_ModelPipeline(nn.Cell):
embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
super(PanguAlpha_ModelPipeline, self).__init__()
self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1)
self.pangu_alpha_embedding.pipeline_stage = 0
self.pangu_alpha_mask = PanguAlpha_Mask(config)
super(PanguAlpha_Model, self).__init__()
self.embedding = PanguAlphaEmbedding(config)
self.blocks = nn.CellList()
self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
embedding_table=TruncatedNormal(0.02))
self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
self.top_query_embedding.expand.shard(((config.dp, 1),))
for i in range(config.num_layers):
if i == config.num_layers - 1:
self.top_query_embedding.set_comm_fusion(2)
self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
per_block = QueryLayer(config).set_comm_fusion(2)
else:
per_block = Decoder(config, i + 1).set_comm_fusion(2)
per_block.pipeline_stage = i * config.stage_num // config.num_layers
per_block.recompute()
self.blocks.append(per_block)
if config.self_layernorm:
self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
else:
self.layernorm = nn.LayerNorm(
(config.embedding_size,)).to_float(mstype.float32)
self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
self.layernorm.set_comm_fusion(2)
self.layernorm.pipeline_stage = config.stage_num - 1
self.use_past = config.use_past
self.dtype = config.compute_dtype
self.num_layers = config.num_layers
self.is_pipeline = (config.stage_num > 1)
if self.is_pipeline:
self.top_query_embedding_table = Parameter(initializer(TruncatedNormal(0.02),
[config.vocab_size, config.embedding_size]),
name='embedding_table', parallel_optimizer=False)
self.top_query_embedding = EmbeddingLookup()
for i in range(config.num_layers):
if i == config.num_layers - 1:
self.top_query_embedding_table.comm_fusion = 2
self.top_query_embedding_table.add_pipeline_stage(i * config.stage_num // config.num_layers)
per_block = QueryLayer(config).set_comm_fusion(2)
else:
per_block = Decoder(config, i + 1).set_comm_fusion(2)
per_block.pipeline_stage = i * config.stage_num // config.num_layers
per_block.recompute()
self.blocks.append(per_block)
self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
self.layernorm.set_comm_fusion(2)
self.layernorm.pipeline_stage = config.stage_num - 1
else:
# The input_position representing the position ids will be used as the index
# into a query embedding table to obtain the top query hidden states; together with the previous outputs of the normal
# self-attention layers, a new attention layer will be attached to the output of the model.
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy')
if os.path.exists(embedding_path):
top_query_table = np.load(embedding_path)
top_query_table_param = Tensor(top_query_table, mstype.float32)
else:
raise ValueError(
f"{embedding_path} file does not exist, please check whether the top_query_embedding file exists.")
else:
top_query_table_param = TruncatedNormal(0.02)
self.top_query_embedding_table = Parameter(initializer(top_query_table_param,
[config.vocab_size, config.embedding_size]),
name='embedding_table', parallel_optimizer=False)
self.top_query_embedding = EmbeddingLookup()
# Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
fusion_group_num = 4
fusion_group_size = config.num_layers // fusion_group_num
fusion_group_size = max(fusion_group_size, 1)
for i in range(config.num_layers):
if i == config.num_layers - 1:
per_block = QueryLayer(config).set_comm_fusion(int(i / fusion_group_size) + 2)
else:
per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
# Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
# in other words, in the backward process each block will be almost totally recomputed.
per_block.recompute()
self.blocks.append(per_block)
self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
self.layernorm.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
self.top_query_embedding_table.comm_fusion = int((config.num_layers - 1) / fusion_group_size) + 2
self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
def construct(self, input_ids, input_mask, table, input_position, attention_mask,
init_reset=True, batch_valid_length=None):
"""PanguAlpha model"""
hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)
# embedding for input_ids and the lower triangle like attention_mask matrix
hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
input_position, attention_mask,
batch_valid_length)
for i in range(self.num_layers-1):
hidden_states, _ = self.blocks[i](hidden_states,
attention_mask, init_reset, batch_valid_length)
top_query_hidden_states = self.top_query_embedding(input_position)
hidden_states, present_layer = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
return output_state, present_layer
if self.is_pipeline:
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
hidden_states, _ = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
else:
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
output_state, _ = self.blocks[self.num_layers-1](output_state, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
return output_state
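For reference, a small sketch of the layer-to-stage assignment formula used in the pipeline branch (illustrative values):

# Sketch only: i * stage_num // num_layers splits layers evenly across pipeline stages.
num_layers, stage_num = 8, 4     # assumed values
stages = [i * stage_num // num_layers for i in range(num_layers)]
# stages == [0, 0, 1, 1, 2, 2, 3, 3]: two consecutive layers per stage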
class PanguAlpha_Head(nn.Cell):
"""
@ -1193,35 +1016,6 @@ class PanguAlpha_Head(nn.Cell):
class PanguAlpha(nn.Cell):
"""
The PanguAlpha network consisting of two parts the backbone and the head
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs
input_mask: the mask indicating whether each position is a valid input
input_position: the position index of each token
attention_mask: the attention_mask attention for self-attention module
init_reset: whether reset saved key and value states
batch_valid_length: the valid input sequence length without padding
Returns:
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
def __init__(self, config):
super(PanguAlpha, self).__init__()
# Network backbone of PanguAlpha
self.backbone = PanguAlpha_Model(config)
# Network head to get logits over vocabulary
self.head = PanguAlpha_Head(config)
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
init_reset=True, batch_valid_length=None):
output_states, embedding_table = self.backbone(
input_ids, input_mask, input_position, attention_mask, init_reset, batch_valid_length)
logits = self.head(output_states, embedding_table)
return logits
class PanguAlphaPipeline(nn.Cell):
"""
The PanguAlpha network consisting of two parts the backbone and the head
Args:
@ -1234,21 +1028,38 @@ class PanguAlphaPipeline(nn.Cell):
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
def __init__(self, config):
super(PanguAlphaPipeline, self).__init__()
self.backbone = PanguAlpha_ModelPipeline(config)
super(PanguAlpha, self).__init__()
# Network head to get logits over vocabulary
self.head = PanguAlpha_Head(config)
self.head.pipeline_stage = config.stage_num - 1
self.vocab_size = config.vocab_size
self.embedding_size = config.embedding_size
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
name="embedding_table")
self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
self.backbone = PanguAlpha_Model(config)
if config.stage_num > 1:
self.head.pipeline_stage = config.stage_num - 1
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
name="embedding_table", parallel_optimizer=False)
self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
else:
if config.load_ckpt_path:
# Loading the embedding table from the ckpt path:
embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
if os.path.exists(embedding_path):
e_table = np.load(embedding_path)
e_table = Tensor(e_table, mstype.float32)
self.embedding_table = Parameter(e_table, name="embedding_table", parallel_optimizer=False)
else:
raise ValueError(f"{embedding_path} file not exits, "
f"please check whether word_embedding file exist.")
else:
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
name="embedding_table", parallel_optimizer=False)
def construct(self, input_ids, input_mask, input_position, attention_mask,
init_reset=True, batch_valid_length=None):
output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
input_position, attention_mask, init_reset, batch_valid_length)
output_states = self.backbone(input_ids, input_mask, self.embedding_table,
input_position, attention_mask, init_reset, batch_valid_length)
logits = self.head(output_states, self.embedding_table)
return logits
@ -1324,7 +1135,6 @@ class CrossEntropyLoss(nn.Cell):
loss = self.div2(numerator, denominator)
return loss
class PanguAlphaWithLoss(nn.Cell):
"""
PanguAlpha training loss
@ -1342,65 +1152,14 @@ class PanguAlphaWithLoss(nn.Cell):
super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
self.network = network
self.loss = loss
# id for end_of_sentence, 6 in the vocabulary
self.eos_token = eos_token
self.slice = P.StridedSlice().shard(((config.dp, 1),))
self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
self.batch_size = config.batch_size
self.len = config.seq_length
self.eod_reset = config.eod_reset
if self.eod_reset:
self.slice_mask = P.StridedSlice().shard(((config.dp, 1, 1),))
def construct(self, input_ids, input_position=None, attention_mask=None):
r"""
PanguAlphaWithLoss
"""
# input_ids [bs, seq_length+1]
# input_position [bs, seq_length] only available when eod_reset enabled
# attention_mask [bs, seq_length, seq_length] only available when eod-reset enabled
# Get input tokens [bs, seq_length]
tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
if self.eod_reset:
input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
(self.batch_size, self.len, self.len),
(1, 1, 1))
# Check whether there is padding in inputs
input_mask = F.cast(self.not_equal(tokens, self.eos_token),
mstype.float32)
logits = self.network(tokens, input_mask, input_position, attention_mask)
# Get label corresponding to input tokens
labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
(1, 1))
# Loss
output = self.loss(logits, labels, input_mask)
return output
class PanguAlphaWithLossPipeline(nn.Cell):
"""
PanguAlpha training loss
Args:
network: backbone network of PanguAlpha
loss: loss function, e.g., crossentropy
eos_token: the end_of_sentence token
Inputs:
input_ids: the tokenized inputs
past: the previous feature map
Returns:
output: Tensor, the loss of the network
"""
def __init__(self, config, network, loss, eos_token=6):
super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False)
self.network = network
self.loss = loss
self.eos_token = eos_token
self.slice = P.StridedSlice().shard(((config.dp, 1),))
self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
self.batch_size = config.batch_size
self.len = config.seq_length
self.micro_batch_step = config.micro_size
self.micro_batch_step = 1
if config.stage_num > 1:
self.micro_batch_step = config.micro_size
def construct(self, input_ids, input_position, attention_mask):
tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1))
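For reference, a numpy sketch of the slicing arithmetic: with pipeline parallelism the loss cell works on batch_size // micro_batch_step rows, and the labels are the inputs shifted by one position (values assumed):

# Sketch only: token/label slicing per micro batch.
import numpy as np

batch_size, seq_length, micro_batch_step = 16, 8, 4    # micro_size applies only when stage_num > 1
input_ids = np.random.randint(0, 100, (batch_size, seq_length + 1))

rows = batch_size // micro_batch_step                  # rows handled per micro batch
tokens = input_ids[:rows, :-1]                         # inputs:  positions 0 .. seq_length-1
labels = input_ids[:rows, 1:]                          # targets: positions 1 .. seq_length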

View File

@ -36,15 +36,13 @@ class PANGUALPHAConfig:
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False,
self_layernorm=True,
word_emb_dp=True,
stage_num=16,
eod_reset=True,
micro_size=32,
load_ckpt_path=None,
use_top_query_attention=True,
param_init_type=mstype.float32,
use_recompute=True):
param_init_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
@ -61,8 +59,6 @@ class PANGUALPHAConfig:
self.use_past = use_past
self.dp = data_parallel_num
self.mp = model_parallel_num
# Whether use self implemented layernorm
self.self_layernorm = self_layernorm
self.stage_num = stage_num
self.micro_size = micro_size
self.word_emb_dp = word_emb_dp
@ -70,7 +66,6 @@ class PANGUALPHAConfig:
# Used for loading embedding tables
self.load_ckpt_path = load_ckpt_path
self.use_top_query_attention = use_top_query_attention
self.use_recompute = use_recompute
self.param_init_type = param_init_type
def __str__(self):
@ -89,7 +84,8 @@ def set_parse(args_opt):
args_opt.embedding_size = 16384
args_opt.num_layers = 64
args_opt.num_heads = 128
args_opt.per_batch_size = 1
if args_opt.per_batch_size == 0:
args_opt.per_batch_size = 1
args_opt.word_emb_dp = 0
if args_opt.run_type == "train":
args_opt.start_lr = 6e-5
@ -117,9 +113,14 @@ def set_parse(args_opt):
args_opt.optimizer_shard = 1
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.full_batch = 0
if args_opt.per_batch_size == 0:
args_opt.per_batch_size = 8
elif args_opt.run_type == "predict":
args_opt.stage_num = 1
args_opt.micro_size = 1
if args_opt.per_batch_size == 0:
args_opt.per_batch_size = 1
elif args_opt.mode == "2.6B":
args_opt.embedding_size = 2560
args_opt.num_layers = 32
@ -131,6 +132,11 @@ def set_parse(args_opt):
args_opt.optimizer_shard = 1
args_opt.stage_num = 1
args_opt.micro_size = 1
args_opt.full_batch = 0
if args_opt.per_batch_size == 0:
args_opt.per_batch_size = 8
elif args_opt.run_type == "predict":
args_opt.stage_num = 1
args_opt.micro_size = 1
if args_opt.per_batch_size == 0:
args_opt.per_batch_size = 1

View File

@ -25,6 +25,7 @@ from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from src.utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1
@ -61,6 +62,7 @@ def _clip_grad(clip_type, clip_value, grad):
grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()
@ -77,6 +79,13 @@ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
_ = F.assign(accu_grad, zeros)
return new_grad
@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
new_grad = grad * reciprocal(scale)
accu_grad = F.depend(accu_grad, new_grad)
_ = F.assign(accu_grad, F.zeros_like(accu_grad))
return new_grad
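For reference, a numpy sketch of the per-tensor arithmetic in the new shard_grad_scale path: the gradient is un-scaled by the reciprocal of the loss scale and the accumulation buffer is cleared (values assumed):

# Sketch only: gradient un-scaling with accumulation buffer reset.
import numpy as np

scale = np.float32(1024.0)                              # dynamic loss scale
grad = np.array([2048.0, -512.0], dtype=np.float32)
accu_grad = np.array([10.0, 20.0], dtype=np.float32)    # accumulated over micro batches

new_grad = grad * (1.0 / scale)                         # reciprocal(scale) * grad -> [2.0, -0.5]
accu_grad[:] = 0.0                                      # equivalent of F.assign(accu_grad, zeros)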
class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
"""
Encapsulation class of PanguAlpha network training.
@ -106,7 +115,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
self.clip = ClipByGlobalNorm(self.weights, config)
self.cast = P.Cast()
def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None):
"""Defines the computation performed."""
weights = self.weights
# Forward process
@ -194,6 +203,7 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
name="loss_scale")
self.clip = ClipByGlobalNorm(self.weights, self.config)
self.micro_size = config.micro_size
self.opt_shard = _get_enable_parallel_optimizer()
@C.add_flags(has_effect=True)
def construct(self,
@ -224,8 +234,12 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
flag_sum = self.reduce_sum(init, (0,))
loss = F.depend(loss, status_clear)
# apply grad reducer on grads
accu_grads = self.grad_reducer(self.accu_grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
if self.opt_shard:
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
else:
accu_grads = self.grad_reducer(self.accu_grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
if self.enable_global_norm:
grads, _ = self.clip(grads)
else:

View File

@ -330,7 +330,7 @@ def add_training_params(opt):
help="Enable optimizer parallel, default is 1")
opt.add_argument("--per_batch_size",
type=int,
default=6,
default=0,
help="The batch size for each data parallel way. default 6")
opt.add_argument("--start_lr",
type=float,

View File

@ -31,8 +31,7 @@ from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\
PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
@ -196,10 +195,7 @@ def run_train_pipeline(args_opt):
The main training process in pipeline mode.
"""
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(save_graphs=False,
mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id)
context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(variable_memory_max_size="31GB")
if args_opt.distribute == "true":
D.init()
@ -210,7 +206,7 @@ def run_train_pipeline(args_opt):
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False,
device_num=device_num,
full_batch=True,
full_batch=bool(args_opt.full_batch),
loss_repeated_mean=True,
enable_parallel_optimizer=bool(args_opt.optimizer_shard),
pipeline_stages=args_opt.stage_num)
@ -219,6 +215,12 @@ def run_train_pipeline(args_opt):
else:
rank_id = int(os.getenv("RANK_ID"))
device_num = 1
# copy data from the cloud to the /cache/Data
cache_url = '/cache/Data/'
if args_opt.offline:
cache_url = args_opt.data_url
else:
download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank_id)
model_parallel_num = args_opt.op_level_model_parallel_num
stage_device_num = int(device_num / args_opt.stage_num)
data_parallel_num = int(stage_device_num / model_parallel_num)
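For reference, a worked example of the device arithmetic above (all values are assumptions):

# Sketch only: per-stage data-parallel size derived from the cluster layout.
device_num = 64                                        # total devices
stage_num = 4                                          # pipeline stages
model_parallel_num = 8                                 # op-level model parallel width

stage_device_num = device_num // stage_num             # 16 devices per pipeline stage
data_parallel_num = stage_device_num // model_parallel_num   # 2 data-parallel copies per stage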
@ -238,20 +240,17 @@ def run_train_pipeline(args_opt):
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False,
self_layernorm=True,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
word_emb_dp=bool(args_opt.word_emb_dp))
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlphaPipeline(config)
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size)
pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLoss(config, pangu_alpha, loss), config.micro_size)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
print("=====args_opt is: ", args_opt, flush=True)
lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step,
decay_steps=args_opt.decay_steps)
lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
params = pangu_alpha.infer_param_pipeline_stage()
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
decay_params = list(filter(decay_filter, params))
@ -269,32 +268,33 @@ def run_train_pipeline(args_opt):
optimizer = nn.Lamb(group_params, learning_rate=lr)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
if context.get_auto_parallel_context("full_batch"):
ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
else:
if batch_size % stage_device_num != 0:
raise ValueError("Batch_size should be divisible by device_num")
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
column_name=args_opt.data_column_name)
epoch_num = args_opt.epoch_size
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
callback = [
TimeMonitor(callback_size),
LossCallBack(callback_size, rank_id, config.stage_num)
]
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num)]
loss_scale_value = math.pow(2, 32)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
scale_factor=2,
scale_window=1000)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
model = Model(pangu_alpha_with_grads)
model.train(actual_epoch_num,
ds,
callbacks=callback,
sink_size=callback_size,
dataset_sink_mode=True)
model.train(actual_epoch_num, ds, callbacks=callback,
sink_size=callback_size, dataset_sink_mode=True)
if __name__ == "__main__":
opt = get_args()
set_parse(opt)
if opt.per_batch_size == 0:
raise ValueError("The per_batch_size has not been configured.")
if opt.stage_num > 1:
run_train_pipeline(opt)
else: