forked from mindspore-Ecosystem/mindspore
!16355 add comments for PanguAlpha
From: @alouhahahahaha Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng
Commit: f71a490400
@@ -43,11 +43,13 @@ def run_predict(args_opt):
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
    # Set execution mode
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
@@ -71,6 +73,7 @@ def run_predict(args_opt):
        rank = 0
        device_num = 1

    # Set model property
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
@@ -99,10 +102,13 @@ def run_predict(args_opt):
    print("=====args_opt is: ", args_opt, flush=True)

    ckpt_name = args_opt.load_ckpt_name
    # Define network
    pangu_alpha = PanguAlpha(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)

    # Compile network and obtain tensor layout for loading ckpt
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    predict_layout = model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
@@ -111,19 +117,24 @@ def run_predict(args_opt):
    ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt") for ckpt_rank in
                      range(0, 512)]
    print(f"Loading from path {ckpt_file_list[0]}", flush=True)
    # Load checkpoint files
    load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
    print("================load param ok=================", flush=True)

    from src.tokenization_jieba import JIEBATokenizer
    from src.generate import generate
    # Define tokenizer
    tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
                               os.path.join(args_opt.tokenizer_path, 'vocab10.model'))

    # Tokenize input sentence to ids
    sample = "今天是一个好天气"
    tokenized_token = tokenizer.tokenize(sample)
    start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
    input_ids = np.array(start_sentence).reshape(1, -1)
    # Call inference
    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
    # Decode output ids to sentence
    output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
    print('Output is:', output_samples, flush=True)
@@ -40,16 +40,22 @@ def get_input_data(input_ids, eod_id, rank, dis):
    input_ids = input_ids[rank*dis: (rank+1)*dis]
    seq_length = input_ids.shape[1] - 1

    # Initialize position_ids and attention_mask
    batch_input_ids = input_ids
    batch_position_ids = np.ones((dis, seq_length))
    batch_attention_mask = np.ones((dis, seq_length, seq_length))

    # Loop through batches
    for bs_i, _ in enumerate(range(len(input_ids))):
        # Get normal position_ids and attention_mask
        local_ids = input_ids[bs_i]
        batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
        batch_position_ids[bs_i] = np.arange(seq_length)
        # Find eod_of_document
        eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
        prev_index = 0
        for i in range(eod_index.size):
            # Reset position_ids and attention_mask considering EOD
            index = eod_index[i]
            batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
            batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
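To make the EOD reset above concrete, here is a small standalone numpy sketch (toy values, not part of this commit) showing how position ids restart and how the causal mask is blocked across document boundaries:

import numpy as np

# Toy example: one sample of 8 tokens, eod_id = 9, two documents separated by an EOD token.
eod_id = 9
ids = np.array([11, 12, 9, 21, 22, 23, 9, 31])
seq_length = len(ids)

position_ids = np.arange(seq_length)
attention_mask = np.tril(np.ones((seq_length, seq_length)))

prev_index = 0
for index in np.where(ids == eod_id)[0]:
    # Tokens after the EOD cannot attend to tokens before (and including) it...
    attention_mask[(index + 1):, :(index + 1)] = 0
    # ...and their position ids restart from zero.
    position_ids[(index + 1):] -= (index + 1 - prev_index)
    prev_index = index + 1

print(position_ids)        # [0 1 2 0 1 2 3 0]
print(attention_mask[3])   # row for the first token of the second document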
@@ -76,6 +82,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        dataset_restore: the dataset for training or evaluating
    """
    ds.config.set_seed(1)

    # Get path for source data files
    home_path = os.path.join(os.getcwd(), data_path)
    files = os.listdir(data_path)
    dis = int(batch_size / device_num)
@@ -89,9 +97,12 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        if not name.endswith(".db")
    ]

    # Load data files and preprocess
    dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_float = C.TypeCast(mstype.float16)

    # If eod_reset is enabled, another two inputs will be generated from input_ids
    if eod_reset:
        map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
        dataset = dataset.batch(batch_size, drop_remainder=drop)
@@ -37,22 +37,30 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
    seq_length = seq_length
    _, valid_length = origin_inputs.shape
    pad_length = seq_length - origin_inputs.shape[-1]
    # Pad original inputs to seq_length
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
    print("input_ids is ", input_ids)

    # Each loop iteration generates one token; loop until reaching the target seq_length or generating the eod token
    while valid_length < seq_length:
        inputs = Tensor(input_ids, mstype.int32)
        # Call a single inference
        probs, p_args = model.predict(inputs)
        # Get the topk values and indices for the current position from the padded outputs
        probs = probs.asnumpy()[valid_length-1, :]
        p_args = p_args.asnumpy()[valid_length-1, :]

        # Randomly select a token as the final output for this round
        p = probs
        p = p / sum(p)
        target_index = np.random.choice(len(p), p=p)
        # Stop judgment
        if p_args[target_index] == end_token or valid_length == seq_length-1:
            outputs = input_ids
            break
        # Modify input_ids with the newly generated token
        input_ids[0][valid_length] = p_args[target_index]
        valid_length += 1
    # Return valid outputs out of the padded outputs
    length = np.sum(outputs != 0)
    outputs = outputs[0][:length]
    return outputs
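For intuition, a minimal numpy sketch of a single sampling step as performed in the loop above, assuming the model has already returned top-k probabilities and their token ids (probs, p_args and the values below are illustrative):

import numpy as np

# Hypothetical top-k output for one position: probabilities and their token ids.
probs = np.array([0.5, 0.2, 0.2, 0.1])   # top-k probabilities (not necessarily normalized)
p_args = np.array([1031, 9, 77, 2045])   # token ids corresponding to each probability
end_token = 9

p = probs / probs.sum()                  # renormalize over the k candidates
target_index = np.random.choice(len(p), p=p)
next_token = p_args[target_index]

if next_token == end_token:
    print("stop: end-of-document token sampled")
else:
    print("append token", next_token)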
@@ -203,7 +203,9 @@ class Output(nn.Cell):
        super(Output, self).__init__()
        input_size = config.embedding_size
        output_size = config.embedding_size * config.expand_ratio
        # Project to expand_ratio*embedding_size
        self.mapping = Mapping_output(config, input_size, output_size)
        # Project back to embedding_size
        self.projection = Mapping(config, output_size, input_size, scale)
        self.activation = nn.GELU()
        self.activation.gelu.shard(((config.dp, 1, config.mp),))
@@ -212,8 +214,10 @@ class Output(nn.Cell):
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, x):
        # [bs, seq_length, expand_ratio*embedding_size]
        hidden = self.activation(self.mapping(x))
        output = self.projection(hidden)
        # [bs, seq_length, embedding_size]
        output = self.dropout(output)
        return output
@@ -235,6 +239,7 @@ class AttentionMask(nn.Cell):
            ((config.dp, 1, 1), (config.dp, 1, 1)))  # yzz: use 64, 1, 1?
        self.expand_dim = P.ExpandDims().shard(((1, 1),))
        ones = np.ones(shape=(config.seq_length, config.seq_length))
        # Default lower triangle mask matrix
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))
@@ -245,12 +250,14 @@ class AttentionMask(nn.Cell):
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
        # Mask the padded inputs
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
        # [bs, seq_length, seq_length]
        attention_mask = self.multiply(
-            attention_mask, lower_traiangle) #bs seq_length seq_length
+            attention_mask, lower_traiangle)
        return attention_mask
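A small standalone numpy sketch (illustrative only) of how the padding mask and the lower-triangular mask combine into the final attention mask built above:

import numpy as np

seq_length = 5
# 1 for real tokens, 0 for padded positions (toy values).
input_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.float32)   # [bs, seq_length]

# Outer product masks out rows/columns that belong to padding.
mask_left = input_mask.reshape(1, seq_length, 1)
mask_right = input_mask.reshape(1, 1, seq_length)
padding_mask = mask_left * mask_right                        # [bs, seq_length, seq_length]

# Causal (lower-triangular) mask forbids attending to future positions.
lower_triangle = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))[None, :, :]

attention_mask = padding_mask * lower_triangle
print(attention_mask[0])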
@@ -305,7 +312,9 @@ class Attention(nn.Cell):
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
        # Attention mask matrix
        self.get_attention_mask = AttentionMask(config)
        # Output layer
        self.projection = Mapping(config, config.embedding_size,
                                  config.embedding_size, scale)
        self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),))
@@ -313,6 +322,7 @@ class Attention(nn.Cell):
            ((config.dp, config.mp, 1, 1),))
        self.reshape = P.Reshape()
        self.n_head = config.num_heads
        # Embedding size per attention head
        self.size_per_head = config.embedding_size // self.n_head
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)
@@ -329,6 +339,7 @@ class Attention(nn.Cell):
            ((config.dp, 1, 1, 1), (1,)))
        self.add = P.TensorAdd().shard(
            ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1)))
        # Normalization factor for attention, sqrt(dk), as widely used
        if self.scale:
            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
        if layer_idx is not None:
@@ -347,16 +358,19 @@ class Attention(nn.Cell):
        self.softmax.softmax.shard(((config.dp, config.mp, 1),))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

        # Query
        self.dense1 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        # Key
        self.dense2 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        # Value
        self.dense3 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
@@ -380,18 +394,22 @@ class Attention(nn.Cell):

        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        # Self attention: query, key, value are derived from the same inputs
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
        # [bs, num_heads, seq_length, size_per_head]
        query = self.transpose(
            F.reshape(
                query,
                (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # [bs, num_heads, size_per_head, seq_length]
        key = self.transpose(
            F.reshape(
                key, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
        # [bs, num_heads, seq_length, size_per_head]
        value = self.transpose(
            F.reshape(
                value,
@@ -403,8 +421,11 @@ class Attention(nn.Cell):
            key = self.concat_k((past_key, key))
            value = self.concat_v(past_value, value)
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
        # Self-attention considering attention mask
        attention = self._attn(query, key, value, attention_mask)
        # [bs, seq_length, embedding_size]
        attention_merge = self.merge_heads(attention)
        # Output
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present
@@ -454,11 +475,14 @@ class Attention(nn.Cell):
        Returns:
            weighted_values: Tensor, the weighted sum scores
        """
        # Normalize query and key before MatMul, default off
        if not self.scale:
            query = query / F.cast(self.coeff, F.dtype(query))
            key = key / F.cast(self.coeff, F.dtype(key))

        # Attention score [bs, num_heads, seq_length, seq_length]
        score = self.batch_matmul(query, key)
        # Normalize after query and key MatMul, default on
        if self.scale:
            score = self.real_div(
                score,
@@ -466,6 +490,7 @@ class Attention(nn.Cell):

        ori_dtype = P.DType()(score)
        score = P.Cast()(score, mstype.float32)
        # Subtract 10000 at masked positions to exclude them from the softmax
        multiplu_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))
@@ -474,13 +499,15 @@ class Attention(nn.Cell):
        attention_scores = self.add(adder, score)

        shape = F.shape(attention_scores)
        # Attention probabilities
        attention_probs = self.softmax(
            F.reshape(attention_scores,
-                      (shape[0], -1, shape[-1]))) # yzz modify
+                      (shape[0], -1, shape[-1])))
        attention_probs = P.Cast()(attention_probs, ori_dtype)
        attention_probs = F.reshape(attention_probs, shape)

        attention_probs = self.prob_dropout(attention_probs)
        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
        weighted_values = self.batch_matmul(attention_probs, value)
        return weighted_values
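A minimal numpy sketch (toy values) of the additive masking used above: masked positions receive a large negative offset before the softmax, so their probabilities collapse to roughly zero:

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Toy attention scores for one head and a causal mask over 4 positions.
score = np.random.randn(4, 4).astype(np.float32)
attention_mask = np.tril(np.ones((4, 4), dtype=np.float32))

# adder is -10000 where the mask is 0, and 0 where the mask is 1.
adder = (1.0 - attention_mask) * -10000.0
attention_probs = softmax(score + adder)

print(np.round(attention_probs, 3))  # upper-triangular entries are ~0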
@@ -517,11 +544,13 @@ class Block(nn.Cell):
        self.attention = Attention(config, scale, layer_idx)
        self.layernorm2.gamma.parallel_optimizer = False
        self.layernorm2.beta.parallel_optimizer = False
        # Feed Forward Network, FFN
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.last_add = P.TensorAdd().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1)))
        # The last activation of this layer will be saved for recompute in the backward process
        self.last_add.recompute(False)
        self.dtype = config.compute_dtype
@@ -529,12 +558,15 @@ class Block(nn.Cell):
        r"""
        The forward process of the block.
        """
        # [bs, seq_length, embedding_size]
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
        attention, layer_present = self.attention(input_x, input_mask,
                                                  layer_past)
        # For post-layernorm, the residual path adds the output of self-attention and the output of the layernorm
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        # For pre-layernorm, the residual path adds the output of self-attention and the input of this layer
        else:
            x = self.add(x, attention)
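The pre-layernorm versus post-layernorm residual wiring above can be summarized with a small framework-agnostic numpy sketch (a simplification, not the commit's code):

import numpy as np

def layernorm(x, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def block_residual(x, sublayer, post_layernorm_residual):
    """Residual wiring around one sublayer (attention or FFN)."""
    input_x = layernorm(x)
    out = sublayer(input_x)
    if post_layernorm_residual:
        # Post-layernorm: the residual starts from the normalized input.
        return out + input_x
    # Pre-layernorm: the residual starts from the raw block input.
    return out + x

x = np.random.randn(2, 8, 16)
y = block_residual(x, sublayer=lambda h: 0.1 * h, post_layernorm_residual=False)
print(y.shape)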
@@ -556,6 +588,7 @@ class QueryLayerAttention(Attention):
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
        # For query_layer_attention, the query is derived from the added query_embedding parameter, while the key and value are derived from the outputs of the previous layer
        query = self.dense1(query_hidden_state)
        key = self.dense2(x)
        value = self.dense3(x)
@@ -611,7 +644,8 @@ class QueryLayer(nn.Cell):

    def construct(self, x, query_hidden_state, input_mask, layer_past=None):
        r"""
        Query Layer.
        The query layer shares a similar structure with a normal layer block,
        except that it does not perform traditional self-attention.
        """
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
@@ -650,6 +684,7 @@ class PanguAlpha_Model(nn.Cell):
    def __init__(self, config):
        super(PanguAlpha_Model, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        # Word embedding
        self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
        if config.load_ckpt_path:
            # Loading the embedding table from the ckpt path:
@@ -662,6 +697,7 @@ class PanguAlpha_Model(nn.Cell):
        else:
            position_table_param = TruncatedNormal(0.02)

        # Position embedding
        self.position_embedding = nn.Embedding(
            config.seq_length,
            config.embedding_size,
@@ -671,11 +707,13 @@ class PanguAlpha_Model(nn.Cell):
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.blocks = nn.CellList()
        # Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
        fusion_group_num = 4
        fusion_group_size = config.num_layers // fusion_group_num
        fusion_group_size = max(fusion_group_size, 1)

        num_layers = config.num_layers
        # If top_query_attention is enabled, the last normal self-attention layer is replaced with the top_query_attention layer
        if config.use_top_query_attention:
            num_layers -= 1
        self.num_layers = num_layers
@@ -683,7 +721,10 @@ class PanguAlpha_Model(nn.Cell):

        for i in range(num_layers):
            per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
            # Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
            # in other words, in the backward process each block will be almost totally recomputed.
            per_block.recompute()
            # Dropout will not be recomputed to ensure consistency between the forward pass and the corresponding backward pass.
            per_block.attention.dropout.dropout_gen_mask.recompute(False)
            per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
            per_block.output.dropout.dropout_gen_mask.recompute(False)
@@ -709,6 +750,9 @@ class PanguAlpha_Model(nn.Cell):
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        self.eod_reset = config.eod_reset
        # If top_query_attention is enabled, input_position (the position ids) will be used as the index
        # into a query embedding table to obtain the top query hidden states; together with the outputs of the normal
        # self-attention layers, a new attention layer will be attached to the output of the model
        if config.use_top_query_attention:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
@@ -745,19 +789,24 @@ class PanguAlpha_Model(nn.Cell):
        if not self.use_past:
            layer_past = self.past

        # Word embedding
        input_embedding, embedding_table = self.word_embedding(input_ids)
        # If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids,
        # and the corresponding input_position and attention_mask will be derived from it.
        if not self.eod_reset:
            batch_size, seq_length = F.shape(input_ids)
            input_position = F.tuple_to_array(F.make_range(seq_length))
            input_position = P.Tile()(input_position, (batch_size, 1))
            attention_mask = self.get_attention_mask(input_mask)
        position_embedding = self.position_embedding(input_position)
        # Input features [bs, seq_length, embedding_size]
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.expand_dims(attention_mask, 1)

        present_layer = ()
        # Loop through each self-attention layer
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask, layer_past)
@@ -766,12 +815,12 @@ class PanguAlpha_Model(nn.Cell):
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)

        # Top query attention layer
        if self.use_top_query_attention:
            top_query_hidden_states = self.top_query_embedding(input_position)
            output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
                                                         attention_mask, layer_past)
            present_layer = present_layer + (present,)

        return output_state, present_layer, embedding_table
@@ -799,6 +848,7 @@ class PanguAlpha_Head(nn.Cell):

    def construct(self, state, embedding_table):
        state = P.Reshape()(state, (-1, self.embedding_size))
        # Output logits over the vocabulary [bs*seq_length, vocab_size]
        logits = self.matmul(state, self.cast(embedding_table, self.dtype))
        return logits
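The head reuses the word embedding table as the output projection (weight tying). A short numpy sketch of the idea, with made-up shapes:

import numpy as np

vocab_size, embedding_size = 1000, 64
bs, seq_length = 2, 8

embedding_table = np.random.randn(vocab_size, embedding_size).astype(np.float32)
state = np.random.randn(bs * seq_length, embedding_size).astype(np.float32)

# Logits over the vocabulary are obtained by multiplying the hidden states
# with the (transposed) embedding table instead of a separate output matrix.
logits = state @ embedding_table.T
print(logits.shape)  # (16, 1000)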
@@ -817,7 +867,9 @@ class PanguAlpha(nn.Cell):
    """
    def __init__(self, config):
        super(PanguAlpha, self).__init__()
        # Network backbone of PanguAlpha
        self.backbone = PanguAlpha_Model(config)
        # Network head to get logits over the vocabulary
        self.head = PanguAlpha_Head(config)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
@@ -844,6 +896,7 @@ class CrossEntropyLoss(nn.Cell):
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum().shard(((config.dp, config.mp),))
        self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ()))
        # On/off values for one-hot encoding; for label smoothing, modify the off_value
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.vocab_size = config.vocab_size
@@ -868,7 +921,9 @@ class CrossEntropyLoss(nn.Cell):
        r"""
        Compute loss using logits, label and input mask
        """
        # [bs*seq_length, vocab_size]
        logits = F.cast(logits, mstype.float32)
        # LogSoftmax for logits over the last dimension
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)
@@ -876,12 +931,17 @@ class CrossEntropyLoss(nn.Cell):
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))

        # Flatten label to [bs*seq_length]
        label = P.Reshape()(label, (-1,))
        # Get one-hot label [bs*seq_length, vocab_size]
        one_hot_label = self.onehot(label, self.vocab_size, self.on_value,
                                    self.off_value)
        # Cross-entropy loss
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
        # input_mask indicates whether there are padded inputs; padded positions are not counted into the loss
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))
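A standalone numpy sketch (toy shapes, not the MindSpore implementation) of the masked cross-entropy computed above, including the max-subtraction trick for a numerically stable log-softmax:

import numpy as np

def masked_cross_entropy(logits, labels, input_mask, eps=1e-24):
    # logits: [bs*seq_length, vocab_size]; labels, input_mask: [bs*seq_length]
    logit_max = logits.max(axis=-1, keepdims=True)
    logit_exp = np.exp(logits - logit_max)           # subtract the max for stability
    softmax = logit_exp / logit_exp.sum(axis=-1, keepdims=True)
    log_softmax = np.log(softmax + eps)

    one_hot = np.eye(logits.shape[-1])[labels]       # [bs*seq_length, vocab_size]
    per_token_loss = -(log_softmax * one_hot).sum(axis=-1)

    # Padded positions (mask == 0) do not contribute to the loss.
    return (per_token_loss * input_mask).sum() / max(input_mask.sum(), 1.0)

logits = np.random.randn(6, 10).astype(np.float32)
labels = np.array([1, 4, 2, 0, 0, 0])
input_mask = np.array([1, 1, 1, 0, 0, 0], dtype=np.float32)
print(masked_cross_entropy(logits, labels, input_mask))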
@@ -909,6 +969,7 @@ class PanguAlphaWithLoss(nn.Cell):
        super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        # id for end_of_sentence, 6 in the vocabulary
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
@@ -922,6 +983,10 @@ class PanguAlphaWithLoss(nn.Cell):
        r"""
        PanguAlphaWithLoss
        """
        # input_ids [bs, seq_length+1]
        # input_position [bs, seq_length], only available when eod_reset is enabled
        # attention_mask [bs, seq_length, seq_length], only available when eod_reset is enabled
        # Get input tokens [bs, seq_length]
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))

        if self.eod_reset:
@@ -929,12 +994,14 @@ class PanguAlphaWithLoss(nn.Cell):
            attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
                                             (self.batch_size, self.len, self.len),
                                             (1, 1, 1))

        # Check whether there is padding in the inputs
        input_mask = F.cast(self.not_equal(tokens, self.eos_token),
                            mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
        # Get the labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
                            (1, 1))
        # Loss
        output = self.loss(logits, labels, input_mask)
        return output
@@ -50,13 +50,17 @@ class PANGUALPHAConfig:
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # The expand ratio of the feature size in the FFN
        self.expand_ratio = expand_ratio
        # Use post-layernorm or pre-layernorm, default: pre-layernorm
        self.post_layernorm_residual = post_layernorm_residual
        self.dropout_rate = dropout_rate
        self.compute_dtype = compute_dtype
        # Whether to use incremental inference
        self.use_past = use_past
        self.dp = data_parallel_num
        self.mp = model_parallel_num
        # Whether to use the self-implemented layernorm
        self.self_layernorm = self_layernorm
        self.stage_num = stage_num
        self.micro_size = micro_size
@@ -45,6 +45,7 @@ def _clip_grad(clip_type, clip_value, grad):
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
    # 0 for clip_by_value and 1 for clip_by_norm
    if clip_type == 0:
        new_grad = C.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
@@ -107,12 +108,14 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        # Forward process
        loss = self.network(input_ids, input_position, attention_mask)
        scaling_sens = self.scale_sense

        # Alloc status and clear should be placed right before the grad operation
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        # Backward process using the loss scale
        grads = self.grad(self.network,
                          weights)(input_ids,
                                   input_position, attention_mask,
@@ -129,8 +132,11 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        grads = self.hyper_map(
            F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
            grads)
        # Check whether overflow occurred
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # If overflow occurred, skip the weights update;
        # if not, update the weights
        if overflow:
            succ = False
        else:
@@ -70,7 +70,9 @@ class GlobalNorm(nn.Cell):
        self.values.append(Tensor([self.group_size*1.0], mstype.float32))
        self.values = tuple(self.values)
    def construct(self, grads):
        # Square sum of gradients for the current rank
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
        # Global square sum of gradients
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms
@@ -77,10 +77,12 @@ def run_train(args_opt):
    The main training process.
    """
    device_id = int(os.getenv('DEVICE_ID'))
    # Set execution mode
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
@@ -102,6 +104,8 @@ def run_train(args_opt):
    else:
        rank = 0
        device_num = 1

    # Set model property
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    batch_size = args_opt.per_batch_size * device_num
@@ -124,18 +128,23 @@ def run_train(args_opt):
        eod_reset=bool(args_opt.eod_reset),
        word_emb_dp=True)
    print("===config is: ", config, flush=True)

    # Define network
    pangu_alpha = PanguAlpha(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

    print("=====args_opt is: ", args_opt, flush=True)

    # Warm-up and cosine decay learning rate
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=200000,
                      lr_scale=1)

    # Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    params = pangu_alpha.trainable_params()
    decay_params = list(filter(decay_filter, params))
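As background for the schedule configured above, a small sketch of linear warm-up followed by cosine decay; the function name and exact formula are assumptions for illustration, not the LearningRate cell's implementation:

import math

def warmup_cosine_lr(step, start_lr, end_lr, warmup_steps, decay_steps):
    if step < warmup_steps:
        # Linear warm-up from 0 to start_lr.
        return start_lr * step / max(warmup_steps, 1)
    # Cosine decay from start_lr down to end_lr over decay_steps.
    progress = min((step - warmup_steps) / max(decay_steps, 1), 1.0)
    return end_lr + 0.5 * (start_lr - end_lr) * (1 + math.cos(math.pi * progress))

for step in (0, 1000, 2000, 100000, 200000):
    print(step, warmup_cosine_lr(step, start_lr=5e-5, end_lr=1e-6,
                                 warmup_steps=2000, decay_steps=200000))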
@@ -153,8 +162,10 @@ def run_train(args_opt):
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
    # Initial scaling sens
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size
    # Dataset loading mindrecord files
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
                        data_start_index=0, eod_reset=config.eod_reset,
                        eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)