!16355 add comments for PanguAlpha

From: @alouhahahahaha
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng,@stsuteng
This commit is contained in:
mindspore-ci-bot 2021-05-18 09:04:29 +08:00 committed by Gitee
commit f71a490400
8 changed files with 126 additions and 6 deletions

View File

@ -43,11 +43,13 @@ def run_predict(args_opt):
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
local_rank = rank_id local_rank = rank_id
print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True) print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
# Set execution mode
context.set_context(save_graphs=False, context.set_context(save_graphs=False,
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
device_target="Ascend", device_target="Ascend",
device_id=device_id) device_id=device_id)
context.set_context(variable_memory_max_size="30GB") context.set_context(variable_memory_max_size="30GB")
# Set parallel context
if args_opt.distribute == "true": if args_opt.distribute == "true":
D.init() D.init()
device_num = D.get_group_size() device_num = D.get_group_size()
@ -71,6 +73,7 @@ def run_predict(args_opt):
rank = 0 rank = 0
device_num = 1 device_num = 1
# Set model property
model_parallel_num = args_opt.tensor_model_parallel_num model_parallel_num = args_opt.tensor_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num) data_parallel_num = int(device_num / model_parallel_num)
per_batch_size = args_opt.per_batch_size per_batch_size = args_opt.per_batch_size
@ -99,10 +102,13 @@ def run_predict(args_opt):
print("=====args_opt is: ", args_opt, flush=True) print("=====args_opt is: ", args_opt, flush=True)
ckpt_name = args_opt.load_ckpt_name ckpt_name = args_opt.load_ckpt_name
# Define network
pangu_alpha = PanguAlpha(config) pangu_alpha = PanguAlpha(config)
eval_net = EvalNet(pangu_alpha) eval_net = EvalNet(pangu_alpha)
eval_net.set_train(False) eval_net.set_train(False)
model_predict = Model(eval_net) model_predict = Model(eval_net)
# Compile network and obtain tensor layout for loading ckpt
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
predict_layout = model_predict.infer_predict_layout(inputs_np) predict_layout = model_predict.infer_predict_layout(inputs_np)
print("======start load_distributed checkpoint", flush=True) print("======start load_distributed checkpoint", flush=True)
@ -111,19 +117,24 @@ def run_predict(args_opt):
ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt") for ckpt_rank in ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt") for ckpt_rank in
range(0, 512)] range(0, 512)]
print(f"Loading from path {ckpt_file_list[0]}", flush=True) print(f"Loading from path {ckpt_file_list[0]}", flush=True)
# Load checkpoint files
load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout) load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
print("================load param ok=================", flush=True) print("================load param ok=================", flush=True)
from src.tokenization_jieba import JIEBATokenizer from src.tokenization_jieba import JIEBATokenizer
from src.generate import generate from src.generate import generate
# Define tokenizer
tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'), tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
os.path.join(args_opt.tokenizer_path, 'vocab10.model')) os.path.join(args_opt.tokenizer_path, 'vocab10.model'))
# Tokenize input sentence to ids
sample = "今天是一个好天气" sample = "今天是一个好天气"
tokenized_token = tokenizer.tokenize(sample) tokenized_token = tokenizer.tokenize(sample)
start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token) start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
input_ids = np.array(start_sentence).reshape(1, -1) input_ids = np.array(start_sentence).reshape(1, -1)
# Call inference
output_ids = generate(model_predict, input_ids, config.seq_length, 9) output_ids = generate(model_predict, input_ids, config.seq_length, 9)
# Decode output ids to sentence
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist()) output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
print('Output is:', output_samples, flush=True) print('Output is:', output_samples, flush=True)

View File

@ -40,16 +40,22 @@ def get_input_data(input_ids, eod_id, rank, dis):
input_ids = input_ids[rank*dis: (rank+1)*dis] input_ids = input_ids[rank*dis: (rank+1)*dis]
seq_length = input_ids.shape[1] - 1 seq_length = input_ids.shape[1] - 1
# Initialize position_ids and attention_mask
batch_input_ids = input_ids batch_input_ids = input_ids
batch_position_ids = np.ones((dis, seq_length)) batch_position_ids = np.ones((dis, seq_length))
batch_attention_mask = np.ones((dis, seq_length, seq_length)) batch_attention_mask = np.ones((dis, seq_length, seq_length))
# Loop through batches
for bs_i, _ in enumerate(range(len(input_ids))): for bs_i, _ in enumerate(range(len(input_ids))):
# Get normal position_ids and attention_mask
local_ids = input_ids[bs_i] local_ids = input_ids[bs_i]
batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length))) batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
batch_position_ids[bs_i] = np.arange(seq_length) batch_position_ids[bs_i] = np.arange(seq_length)
# Find eod_of_document
eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32) eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
prev_index = 0 prev_index = 0
for i in range(eod_index.size): for i in range(eod_index.size):
# Reset position_ids and attention_mask considering EOD
index = eod_index[i] index = eod_index[i]
batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0 batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index) batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
@ -76,6 +82,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
dataset_restore: the dataset for training or evaluating dataset_restore: the dataset for training or evaluating
""" """
ds.config.set_seed(1) ds.config.set_seed(1)
# Get path for source data files
home_path = os.path.join(os.getcwd(), data_path) home_path = os.path.join(os.getcwd(), data_path)
files = os.listdir(data_path) files = os.listdir(data_path)
dis = int(batch_size / device_num) dis = int(batch_size / device_num)
@ -89,9 +97,12 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
if not name.endswith(".db") if not name.endswith(".db")
] ]
# Load data files and preprocess
dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False) dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
type_cast_op_float = C.TypeCast(mstype.float16) type_cast_op_float = C.TypeCast(mstype.float16)
# If eod_reset enabled, another two inputs will be generated through input_ids
if eod_reset: if eod_reset:
map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis)) map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
dataset = dataset.batch(batch_size, drop_remainder=drop) dataset = dataset.batch(batch_size, drop_remainder=drop)

View File

@ -37,22 +37,30 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
seq_length = seq_length seq_length = seq_length
_, valid_length = origin_inputs.shape _, valid_length = origin_inputs.shape
pad_length = seq_length - origin_inputs.shape[-1] pad_length = seq_length - origin_inputs.shape[-1]
# Pad original inputs to seq_length
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0)) input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
print("input_ids is ", input_ids) print("input_ids is ", input_ids)
# A single loop generates one token, loop until reaching target seq_length or generating eod token
while valid_length < seq_length: while valid_length < seq_length:
inputs = Tensor(input_ids, mstype.int32) inputs = Tensor(input_ids, mstype.int32)
# Call a single inference
probs, p_args = model.predict(inputs) probs, p_args = model.predict(inputs)
# Get the topk value and index for the current position from the padded outputs
probs = probs.asnumpy()[valid_length-1, :] probs = probs.asnumpy()[valid_length-1, :]
p_args = p_args.asnumpy()[valid_length-1, :] p_args = p_args.asnumpy()[valid_length-1, :]
# Random select a token as final output for this round
p = probs p = probs
p = p / sum(p) p = p / sum(p)
target_index = np.random.choice(len(p), p=p) target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == seq_length-1: if p_args[target_index] == end_token or valid_length == seq_length-1:
outputs = input_ids outputs = input_ids
break break
# Modify input_ids with newly generated token
input_ids[0][valid_length] = p_args[target_index] input_ids[0][valid_length] = p_args[target_index]
valid_length += 1 valid_length += 1
# Return valid outputs out of padded outputs
length = np.sum(outputs != 0) length = np.sum(outputs != 0)
outputs = outputs[0][:length] outputs = outputs[0][:length]
return outputs return outputs

View File

@ -203,7 +203,9 @@ class Output(nn.Cell):
super(Output, self).__init__() super(Output, self).__init__()
input_size = config.embedding_size input_size = config.embedding_size
output_size = config.embedding_size * config.expand_ratio output_size = config.embedding_size * config.expand_ratio
# Project to expand_ratio*embedding_size
self.mapping = Mapping_output(config, input_size, output_size) self.mapping = Mapping_output(config, input_size, output_size)
# Project back to embedding_size
self.projection = Mapping(config, output_size, input_size, scale) self.projection = Mapping(config, output_size, input_size, scale)
self.activation = nn.GELU() self.activation = nn.GELU()
self.activation.gelu.shard(((config.dp, 1, config.mp),)) self.activation.gelu.shard(((config.dp, 1, config.mp),))
@ -212,8 +214,10 @@ class Output(nn.Cell):
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
def construct(self, x): def construct(self, x):
# [bs, seq_length, expand_ratio*embedding_size]
hidden = self.activation(self.mapping(x)) hidden = self.activation(self.mapping(x))
output = self.projection(hidden) output = self.projection(hidden)
# [bs, seq_length, expand_ratio]
output = self.dropout(output) output = self.dropout(output)
return output return output
@ -235,6 +239,7 @@ class AttentionMask(nn.Cell):
((config.dp, 1, 1), (config.dp, 1, 1))) # yzz: use 64, 1, 1? ((config.dp, 1, 1), (config.dp, 1, 1))) # yzz: use 64, 1, 1?
self.expand_dim = P.ExpandDims().shard(((1, 1),)) self.expand_dim = P.ExpandDims().shard(((1, 1),))
ones = np.ones(shape=(config.seq_length, config.seq_length)) ones = np.ones(shape=(config.seq_length, config.seq_length))
# Default lower triangle mask matrix
self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32) self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1))) self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))
@ -245,12 +250,14 @@ class AttentionMask(nn.Cell):
input_shape = P.Shape()(input_mask) input_shape = P.Shape()(input_mask)
shape_right = (input_shape[0], 1, input_shape[1]) shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,) shape_left = input_shape + (1,)
# Mask the padded inputs
mask_left = self.reshape(input_mask, shape_left) mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right) mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.mul(mask_left, mask_right) attention_mask = self.mul(mask_left, mask_right)
lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0) lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
# [bs, seq_length, seq_length]
attention_mask = self.multiply( attention_mask = self.multiply(
attention_mask, lower_traiangle) #bs seq_length seq_length attention_mask, lower_traiangle)
return attention_mask return attention_mask
@ -305,7 +312,9 @@ class Attention(nn.Cell):
""" """
def __init__(self, config, scale=1.0, layer_idx=None): def __init__(self, config, scale=1.0, layer_idx=None):
super(Attention, self).__init__() super(Attention, self).__init__()
# Attention mask matrix
self.get_attention_mask = AttentionMask(config) self.get_attention_mask = AttentionMask(config)
# Output layer
self.projection = Mapping(config, config.embedding_size, self.projection = Mapping(config, config.embedding_size,
config.embedding_size, scale) config.embedding_size, scale)
self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),)) self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),))
@ -313,6 +322,7 @@ class Attention(nn.Cell):
((config.dp, config.mp, 1, 1),)) ((config.dp, config.mp, 1, 1),))
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.n_head = config.num_heads self.n_head = config.num_heads
# embedding size per head
self.size_per_head = config.embedding_size // self.n_head self.size_per_head = config.embedding_size // self.n_head
self.concat_k = P.Concat(axis=3) self.concat_k = P.Concat(axis=3)
self.concat_v = P.Concat(axis=2) self.concat_v = P.Concat(axis=2)
@ -329,6 +339,7 @@ class Attention(nn.Cell):
((config.dp, 1, 1, 1), (1,))) ((config.dp, 1, 1, 1), (1,)))
self.add = P.TensorAdd().shard( self.add = P.TensorAdd().shard(
((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1))) ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1)))
# Normalize factor for attention, sqrt(dk) as widely used
if self.scale: if self.scale:
self.scale_factor = Tensor(math.sqrt(self.size_per_head)) self.scale_factor = Tensor(math.sqrt(self.size_per_head))
if layer_idx is not None: if layer_idx is not None:
@ -347,16 +358,19 @@ class Attention(nn.Cell):
self.softmax.softmax.shard(((config.dp, config.mp, 1),)) self.softmax.softmax.shard(((config.dp, config.mp, 1),))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
# Query
self.dense1 = nn.Dense(config.embedding_size, self.dense1 = nn.Dense(config.embedding_size,
config.embedding_size).to_float( config.embedding_size).to_float(
config.compute_dtype) config.compute_dtype)
self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1))) self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,))) self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
# Key
self.dense2 = nn.Dense(config.embedding_size, self.dense2 = nn.Dense(config.embedding_size,
config.embedding_size).to_float( config.embedding_size).to_float(
config.compute_dtype) config.compute_dtype)
self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1))) self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,))) self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
# Value
self.dense3 = nn.Dense(config.embedding_size, self.dense3 = nn.Dense(config.embedding_size,
config.embedding_size).to_float( config.embedding_size).to_float(
config.compute_dtype) config.compute_dtype)
@ -380,18 +394,22 @@ class Attention(nn.Cell):
original_shape = F.shape(x) original_shape = F.shape(x)
x = F.reshape(x, (-1, original_shape[-1])) x = F.reshape(x, (-1, original_shape[-1]))
# Self attention: query, key, value are derived from the same inputs
query = self.dense1(x) query = self.dense1(x)
key = self.dense2(x) key = self.dense2(x)
value = self.dense3(x) value = self.dense3(x)
# [bs, num_heads, seq_length, size_per_head]
query = self.transpose( query = self.transpose(
F.reshape( F.reshape(
query, query,
(-1, original_shape[1], self.n_head, self.size_per_head)), (-1, original_shape[1], self.n_head, self.size_per_head)),
(0, 2, 1, 3)) (0, 2, 1, 3))
# [bs, num_heads, size_per_head, seq_length]
key = self.transpose( key = self.transpose(
F.reshape( F.reshape(
key, (-1, original_shape[1], self.n_head, self.size_per_head)), key, (-1, original_shape[1], self.n_head, self.size_per_head)),
(0, 2, 3, 1)) (0, 2, 3, 1))
# [bs, num_heads, seq_length, size_per_head]
value = self.transpose( value = self.transpose(
F.reshape( F.reshape(
value, value,
@ -403,8 +421,11 @@ class Attention(nn.Cell):
key = self.concat_k((past_key, key)) key = self.concat_k((past_key, key))
value = self.concat_v(past_value, value) value = self.concat_v(past_value, value)
layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
# Self-attention considering attention mask
attention = self._attn(query, key, value, attention_mask) attention = self._attn(query, key, value, attention_mask)
# [bs, seq_length, embedding_size]
attention_merge = self.merge_heads(attention) attention_merge = self.merge_heads(attention)
# Output
output = self.projection(attention_merge) output = self.projection(attention_merge)
output = self.dropout(output) output = self.dropout(output)
return output, layer_present return output, layer_present
@ -454,11 +475,14 @@ class Attention(nn.Cell):
Returns: Returns:
weighted_values: Tensor, the weighted sum scores weighted_values: Tensor, the weighted sum scores
""" """
# Normalize query and key before MatMul, default off
if not self.scale: if not self.scale:
query = query / F.cast(self.coeff, F.dtype(query)) query = query / F.cast(self.coeff, F.dtype(query))
key = key / F.cast(self.coeff, F.dtype(key)) key = key / F.cast(self.coeff, F.dtype(key))
# Attention score [bs, num_heads, seq_length, seq_length]
score = self.batch_matmul(query, key) score = self.batch_matmul(query, key)
# Normalize after query and key MatMul, default on
if self.scale: if self.scale:
score = self.real_div( score = self.real_div(
score, score,
@ -466,6 +490,7 @@ class Attention(nn.Cell):
ori_dtype = P.DType()(score) ori_dtype = P.DType()(score)
score = P.Cast()(score, mstype.float32) score = P.Cast()(score, mstype.float32)
# Minus 10000 for the position where masked to exclude them from softmax
multiplu_out = self.sub( multiplu_out = self.sub(
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)), P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
P.Cast()(attention_mask, P.DType()(score))) P.Cast()(attention_mask, P.DType()(score)))
@ -474,13 +499,15 @@ class Attention(nn.Cell):
attention_scores = self.add(adder, score) attention_scores = self.add(adder, score)
shape = F.shape(attention_scores) shape = F.shape(attention_scores)
# attention probs
attention_probs = self.softmax( attention_probs = self.softmax(
F.reshape(attention_scores, F.reshape(attention_scores,
(shape[0], -1, shape[-1]))) # yzz modify (shape[0], -1, shape[-1])))
attention_probs = P.Cast()(attention_probs, ori_dtype) attention_probs = P.Cast()(attention_probs, ori_dtype)
attention_probs = F.reshape(attention_probs, shape) attention_probs = F.reshape(attention_probs, shape)
attention_probs = self.prob_dropout(attention_probs) attention_probs = self.prob_dropout(attention_probs)
# Weighted sum output [bs, num_heads, seq_length, size_per_head]
weighted_values = self.batch_matmul(attention_probs, value) weighted_values = self.batch_matmul(attention_probs, value)
return weighted_values return weighted_values
@ -517,11 +544,13 @@ class Block(nn.Cell):
self.attention = Attention(config, scale, layer_idx) self.attention = Attention(config, scale, layer_idx)
self.layernorm2.gamma.parallel_optimizer = False self.layernorm2.gamma.parallel_optimizer = False
self.layernorm2.beta.parallel_optimizer = False self.layernorm2.beta.parallel_optimizer = False
# Feed Forward Network, FFN
self.output = Output(config, scale) self.output = Output(config, scale)
self.post_layernorm_residual = config.post_layernorm_residual self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.last_add = P.TensorAdd().shard( self.last_add = P.TensorAdd().shard(
((config.dp, 1, 1), (config.dp, 1, 1))) ((config.dp, 1, 1), (config.dp, 1, 1)))
# Last activation of this layer will be saved for recompute in backward process
self.last_add.recompute(False) self.last_add.recompute(False)
self.dtype = config.compute_dtype self.dtype = config.compute_dtype
@ -529,12 +558,15 @@ class Block(nn.Cell):
r""" r"""
The forward process of the block. The forward process of the block.
""" """
# [bs, seq_length, embedding_size]
input_x = self.layernorm1(x) input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype) input_x = F.cast(input_x, self.dtype)
attention, layer_present = self.attention(input_x, input_mask, attention, layer_present = self.attention(input_x, input_mask,
layer_past) layer_past)
# For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
if self.post_layernorm_residual: if self.post_layernorm_residual:
x = self.add(input_x, attention) x = self.add(input_x, attention)
# For pre-layernorm the inputs for residual path are output of self-attention and input of this layer
else: else:
x = self.add(x, attention) x = self.add(x, attention)
@ -556,6 +588,7 @@ class QueryLayerAttention(Attention):
original_shape = F.shape(x) original_shape = F.shape(x)
x = F.reshape(x, (-1, original_shape[-1])) x = F.reshape(x, (-1, original_shape[-1]))
query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1])) query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
# For query_layer_attention, query are derived from outputs of previous layer and key, value are derived from an added parameter query_embedding
query = self.dense1(query_hidden_state) query = self.dense1(query_hidden_state)
key = self.dense2(x) key = self.dense2(x)
value = self.dense3(x) value = self.dense3(x)
@ -611,7 +644,8 @@ class QueryLayer(nn.Cell):
def construct(self, x, query_hidden_state, input_mask, layer_past=None): def construct(self, x, query_hidden_state, input_mask, layer_past=None):
r""" r"""
Query Layer. Query Layer shares a similar structure with normal layer block
except that it is not a traditional self-attention.
""" """
input_x = self.layernorm1(x) input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype) input_x = F.cast(input_x, self.dtype)
@ -650,6 +684,7 @@ class PanguAlpha_Model(nn.Cell):
def __init__(self, config): def __init__(self, config):
super(PanguAlpha_Model, self).__init__() super(PanguAlpha_Model, self).__init__()
self.get_attention_mask = AttentionMask(config) self.get_attention_mask = AttentionMask(config)
# Word embedding
self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1) self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
if config.load_ckpt_path: if config.load_ckpt_path:
# Loading the embedding table from the ckpt path: # Loading the embedding table from the ckpt path:
@ -662,6 +697,7 @@ class PanguAlpha_Model(nn.Cell):
else: else:
position_table_param = TruncatedNormal(0.02) position_table_param = TruncatedNormal(0.02)
# Position embedding
self.position_embedding = nn.Embedding( self.position_embedding = nn.Embedding(
config.seq_length, config.seq_length,
config.embedding_size, config.embedding_size,
@ -671,11 +707,13 @@ class PanguAlpha_Model(nn.Cell):
self.position_embedding.gather.shard(((1, 1), (config.dp,))) self.position_embedding.gather.shard(((1, 1), (config.dp,)))
self.position_embedding.expand.shard(((config.dp, 1),)) self.position_embedding.expand.shard(((config.dp, 1),))
self.blocks = nn.CellList() self.blocks = nn.CellList()
# Total fusion groups for HCCL operators. Specifically, the same tyep HCCL operators in same group will be fused.
fusion_group_num = 4 fusion_group_num = 4
fusion_group_size = config.num_layers // fusion_group_num fusion_group_size = config.num_layers // fusion_group_num
fusion_group_size = max(fusion_group_size, 1) fusion_group_size = max(fusion_group_size, 1)
num_layers = config.num_layers num_layers = config.num_layers
# If top_query_attention enabled, replace the last normal self-attention layers with this top_query_attention layer
if config.use_top_query_attention: if config.use_top_query_attention:
num_layers -= 1 num_layers -= 1
self.num_layers = num_layers self.num_layers = num_layers
@ -683,7 +721,10 @@ class PanguAlpha_Model(nn.Cell):
for i in range(num_layers): for i in range(num_layers):
per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2) per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
# Each layer will be remoputed in the backward process. The output activation of each layer will be saved,
# in other words, in backward process each block will be almosttotally recomputed.
per_block.recompute() per_block.recompute()
# Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward.
per_block.attention.dropout.dropout_gen_mask.recompute(False) per_block.attention.dropout.dropout_gen_mask.recompute(False)
per_block.attention.prob_dropout.dropout_gen_mask.recompute(False) per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
per_block.output.dropout.dropout_gen_mask.recompute(False) per_block.output.dropout.dropout_gen_mask.recompute(False)
@ -709,6 +750,9 @@ class PanguAlpha_Model(nn.Cell):
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.eod_reset = config.eod_reset self.eod_reset = config.eod_reset
# If top_query_attention enabled, the input_position representing the position ids will be used as the index
# for a query embedding table to obtain top query hidden states, together with the previous outputs of normal
# self-attention layers, a new attention layer will be attached to the output of the model
if config.use_top_query_attention: if config.use_top_query_attention:
if config.load_ckpt_path: if config.load_ckpt_path:
# Loading the embedding table from the ckpt path: # Loading the embedding table from the ckpt path:
@ -745,19 +789,24 @@ class PanguAlpha_Model(nn.Cell):
if not self.use_past: if not self.use_past:
layer_past = self.past layer_past = self.past
# Word embedding
input_embedding, embedding_table = self.word_embedding(input_ids) input_embedding, embedding_table = self.word_embedding(input_ids)
# If eod_reset disabled, there will be only one input from the dataset, i.e., input_ids
# and the corresponding input_position and attention_mask will be derived from it.
if not self.eod_reset: if not self.eod_reset:
batch_size, seq_length = F.shape(input_ids) batch_size, seq_length = F.shape(input_ids)
input_position = F.tuple_to_array(F.make_range(seq_length)) input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (batch_size, 1)) input_position = P.Tile()(input_position, (batch_size, 1))
attention_mask = self.get_attention_mask(input_mask) attention_mask = self.get_attention_mask(input_mask)
position_embedding = self.position_embedding(input_position) position_embedding = self.position_embedding(input_position)
# Input features [bs, seq_length, embedding_size]
hidden_states = self.add(input_embedding, position_embedding) hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = P.Cast()(hidden_states, mstype.float16) hidden_states = P.Cast()(hidden_states, mstype.float16)
attention_mask = self.expand_dims(attention_mask, 1) attention_mask = self.expand_dims(attention_mask, 1)
present_layer = () present_layer = ()
# Loop through each self-attention layer
for i in range(self.num_layers): for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states, hidden_states, present = self.blocks[i](hidden_states,
attention_mask, layer_past) attention_mask, layer_past)
@ -766,12 +815,12 @@ class PanguAlpha_Model(nn.Cell):
output_state = self.layernorm(hidden_states) output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype) output_state = F.cast(output_state, self.dtype)
# Top query attention layer
if self.use_top_query_attention: if self.use_top_query_attention:
top_query_hidden_states = self.top_query_embedding(input_position) top_query_hidden_states = self.top_query_embedding(input_position)
output_state, present = self.top_query_layer(output_state, top_query_hidden_states, output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
attention_mask, layer_past) attention_mask, layer_past)
present_layer = present_layer + (present,) present_layer = present_layer + (present,)
return output_state, present_layer, embedding_table return output_state, present_layer, embedding_table
@ -799,6 +848,7 @@ class PanguAlpha_Head(nn.Cell):
def construct(self, state, embedding_table): def construct(self, state, embedding_table):
state = P.Reshape()(state, (-1, self.embedding_size)) state = P.Reshape()(state, (-1, self.embedding_size))
# output logits over vocabulary [bs*seq_length, vocab_size]
logits = self.matmul(state, self.cast(embedding_table, self.dtype)) logits = self.matmul(state, self.cast(embedding_table, self.dtype))
return logits return logits
@ -817,7 +867,9 @@ class PanguAlpha(nn.Cell):
""" """
def __init__(self, config): def __init__(self, config):
super(PanguAlpha, self).__init__() super(PanguAlpha, self).__init__()
# Network backbone of PanguAlpha
self.backbone = PanguAlpha_Model(config) self.backbone = PanguAlpha_Model(config)
# Network head to get logits over vocabulary
self.head = PanguAlpha_Head(config) self.head = PanguAlpha_Head(config)
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None): def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
@ -844,6 +896,7 @@ class CrossEntropyLoss(nn.Cell):
self.mean = P.ReduceMean() self.mean = P.ReduceMean()
self.sum = P.ReduceSum().shard(((config.dp, config.mp),)) self.sum = P.ReduceSum().shard(((config.dp, config.mp),))
self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ())) self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ()))
# on/off value for onehot, for smooth labeling, modify the off_value
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32) self.off_value = Tensor(0.0, mstype.float32)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
@ -868,7 +921,9 @@ class CrossEntropyLoss(nn.Cell):
r""" r"""
Compute loss using logits, label and input mask Compute loss using logits, label and input mask
""" """
# [bs*seq_length, vocab_size]
logits = F.cast(logits, mstype.float32) logits = F.cast(logits, mstype.float32)
# LogSoftmax for logits over last dimension
_, logit_max = self.max(logits) _, logit_max = self.max(logits)
logit_sub = self.sub(logits, logit_max) logit_sub = self.sub(logits, logit_max)
logit_exp = self.exp(logit_sub) logit_exp = self.exp(logit_sub)
@ -876,12 +931,17 @@ class CrossEntropyLoss(nn.Cell):
exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1)) exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
softmax_result = self.div(logit_exp, exp_sum) softmax_result = self.div(logit_exp, exp_sum)
log_softmax_result = self.log(self.add(softmax_result, self.eps_const)) log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
# Flatten label to [bs*seq_length]
label = P.Reshape()(label, (-1,)) label = P.Reshape()(label, (-1,))
# Get onehot label [bs*seq_length, vocab_size]
one_hot_label = self.onehot(label, self.vocab_size, self.on_value, one_hot_label = self.onehot(label, self.vocab_size, self.on_value,
self.off_value) self.off_value)
# Cross-Entropy loss
loss = self.mul(log_softmax_result, one_hot_label) loss = self.mul(log_softmax_result, one_hot_label)
loss_unsum = self.neg(loss) loss_unsum = self.neg(loss)
loss_reduce = self.sum(loss_unsum, -1) loss_reduce = self.sum(loss_unsum, -1)
# input_mask indicates whether there is padded inputs and for padded inputs it will not be counted into loss
input_mask = P.Reshape()(input_mask, (-1,)) input_mask = P.Reshape()(input_mask, (-1,))
numerator = self.sum2(self.mul2(loss_reduce, input_mask)) numerator = self.sum2(self.mul2(loss_reduce, input_mask))
@ -909,6 +969,7 @@ class PanguAlphaWithLoss(nn.Cell):
super(PanguAlphaWithLoss, self).__init__(auto_prefix=False) super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.loss = loss self.loss = loss
# id for end_of_sentence, 6 in the vocabulary
self.eos_token = eos_token self.eos_token = eos_token
self.slice = P.StridedSlice().shard(((config.dp, 1),)) self.slice = P.StridedSlice().shard(((config.dp, 1),))
self.not_equal = P.NotEqual().shard(((config.dp, 1), ())) self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
@ -922,6 +983,10 @@ class PanguAlphaWithLoss(nn.Cell):
r""" r"""
PanguAlphaWithLoss PanguAlphaWithLoss
""" """
# input_ids [bs, seq_length+1]
# input_position [bs, seq_length] only available when eod_reset enabled
# attention_mask [bs, seq_length, seq_length] only available when eod-reset enabled
# Get input tokens [bs, seq_length]
tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
if self.eod_reset: if self.eod_reset:
@ -929,12 +994,14 @@ class PanguAlphaWithLoss(nn.Cell):
attention_mask = self.slice_mask(attention_mask, (0, 0, 0), attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
(self.batch_size, self.len, self.len), (self.batch_size, self.len, self.len),
(1, 1, 1)) (1, 1, 1))
# Check whether there is padding in inputs
input_mask = F.cast(self.not_equal(tokens, self.eos_token), input_mask = F.cast(self.not_equal(tokens, self.eos_token),
mstype.float32) mstype.float32)
logits = self.network(tokens, input_mask, input_position, attention_mask) logits = self.network(tokens, input_mask, input_position, attention_mask)
# Get label corresponding to input tokens
labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1), labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
(1, 1)) (1, 1))
# Loss
output = self.loss(logits, labels, input_mask) output = self.loss(logits, labels, input_mask)
return output return output

View File

@ -50,13 +50,17 @@ class PANGUALPHAConfig:
self.embedding_size = embedding_size self.embedding_size = embedding_size
self.num_layers = num_layers self.num_layers = num_layers
self.num_heads = num_heads self.num_heads = num_heads
# The expand ratio of feature size in FFN
self.expand_ratio = expand_ratio self.expand_ratio = expand_ratio
# Use post-layernorm or pre-layernrom, default:pre-layernorm
self.post_layernorm_residual = post_layernorm_residual self.post_layernorm_residual = post_layernorm_residual
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.compute_dtype = compute_dtype self.compute_dtype = compute_dtype
# Whether use incremental inference
self.use_past = use_past self.use_past = use_past
self.dp = data_parallel_num self.dp = data_parallel_num
self.mp = model_parallel_num self.mp = model_parallel_num
# Whether use self implemented layernorm
self.self_layernorm = self_layernorm self.self_layernorm = self_layernorm
self.stage_num = stage_num self.stage_num = stage_num
self.micro_size = micro_size self.micro_size = micro_size

View File

@ -45,6 +45,7 @@ def _clip_grad(clip_type, clip_value, grad):
if clip_type not in [0, 1]: if clip_type not in [0, 1]:
return grad return grad
dt = F.dtype(grad) dt = F.dtype(grad)
# 0 for clip_by_value and 1 for clip_by_norm
if clip_type == 0: if clip_type == 0:
new_grad = C.clip_by_value( new_grad = C.clip_by_value(
grad, F.cast(F.tuple_to_array((-clip_value,)), dt), grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
@ -107,12 +108,14 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None): def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
"""Defines the computation performed.""" """Defines the computation performed."""
weights = self.weights weights = self.weights
# Forward process
loss = self.network(input_ids, input_position, attention_mask) loss = self.network(input_ids, input_position, attention_mask)
scaling_sens = self.scale_sense scaling_sens = self.scale_sense
# alloc status and clear should be right before gradoperation # alloc status and clear should be right before gradoperation
status, scaling_sens = self.start_overflow_check(loss, scaling_sens) status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
# Backward process using loss scale
grads = self.grad(self.network, grads = self.grad(self.network,
weights)(input_ids, weights)(input_ids,
input_position, attention_mask, input_position, attention_mask,
@ -129,8 +132,11 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
grads = self.hyper_map( grads = self.hyper_map(
F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
grads) grads)
# Check whether overflow
cond = self.get_overflow_status(status, grads) cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond) overflow = self.process_loss_scale(cond)
# If overflow, surpass weights update
# if not, update weights
if overflow: if overflow:
succ = False succ = False
else: else:

View File

@ -70,7 +70,9 @@ class GlobalNorm(nn.Cell):
self.values.append(Tensor([self.group_size*1.0], mstype.float32)) self.values.append(Tensor([self.group_size*1.0], mstype.float32))
self.values = tuple(self.values) self.values = tuple(self.values)
def construct(self, grads): def construct(self, grads):
# Square sum of gradients for current rank
square_sum_dp = self.hyper_map(get_square_sum, grads, self.values) square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
# Global square sum of gradients
global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp))) global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
return global_norms return global_norms

View File

@ -77,10 +77,12 @@ def run_train(args_opt):
The main training process. The main training process.
""" """
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
# Set execution mode
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target="Ascend",
device_id=device_id) device_id=device_id)
context.set_context(variable_memory_max_size="30GB") context.set_context(variable_memory_max_size="30GB")
# Set parallel context
if args_opt.distribute == "true": if args_opt.distribute == "true":
D.init() D.init()
device_num = D.get_group_size() device_num = D.get_group_size()
@ -102,6 +104,8 @@ def run_train(args_opt):
else: else:
rank = 0 rank = 0
device_num = 1 device_num = 1
# Set model property
model_parallel_num = args_opt.tensor_model_parallel_num model_parallel_num = args_opt.tensor_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num) data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * device_num batch_size = args_opt.per_batch_size * device_num
@ -124,18 +128,23 @@ def run_train(args_opt):
eod_reset=bool(args_opt.eod_reset), eod_reset=bool(args_opt.eod_reset),
word_emb_dp=True) word_emb_dp=True)
print("===config is: ", config, flush=True) print("===config is: ", config, flush=True)
# Define network
pangu_alpha = PanguAlpha(config) pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config) loss = CrossEntropyLoss(config)
pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss) pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss) pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
print("=====args_opt is: ", args_opt, flush=True) print("=====args_opt is: ", args_opt, flush=True)
# Warm-up and cosine decay learning rate
lr = LearningRate(learning_rate=args_opt.start_lr, lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr, end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step, warmup_steps=args_opt.warmup_step,
decay_steps=200000, decay_steps=200000,
lr_scale=1) lr_scale=1)
# Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
params = pangu_alpha.trainable_params() params = pangu_alpha.trainable_params()
decay_params = list(filter(decay_filter, params)) decay_params = list(filter(decay_filter, params))
@ -153,8 +162,10 @@ def run_train(args_opt):
optimizer = nn.Lamb(group_params, learning_rate=lr) optimizer = nn.Lamb(group_params, learning_rate=lr)
else: else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
# Initial scaling sens
loss_scale_value = math.pow(2, 32) loss_scale_value = math.pow(2, 32)
epoch_num = args_opt.epoch_size epoch_num = args_opt.epoch_size
# Dataset loading mindrecord files
ds = create_dataset(config.batch_size, data_path=args_opt.data_url, ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
data_start_index=0, eod_reset=config.eod_reset, data_start_index=0, eod_reset=config.eod_reset,
eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num) eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)