forked from mindspore-Ecosystem/mindspore
!16355 add comments for PanguAlpha
From: @alouhahahahaha Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng,@stsuteng
This commit is contained in: commit f71a490400
@@ -43,11 +43,13 @@ def run_predict(args_opt):
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
+    # Set execution mode
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
+    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()

@@ -71,6 +73,7 @@ def run_predict(args_opt):
        rank = 0
        device_num = 1

+    # Set model property
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size

@@ -99,10 +102,13 @@ def run_predict(args_opt):
    print("=====args_opt is: ", args_opt, flush=True)

    ckpt_name = args_opt.load_ckpt_name
+    # Define network
    pangu_alpha = PanguAlpha(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)

+    # Compile network and obtain tensor layout for loading ckpt
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    predict_layout = model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)

@@ -111,19 +117,24 @@ def run_predict(args_opt):
    ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt") for ckpt_rank in
                      range(0, 512)]
    print(f"Loading from path {ckpt_file_list[0]}", flush=True)
+    # Load checkpoint files
    load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
    print("================load param ok=================", flush=True)

    from src.tokenization_jieba import JIEBATokenizer
    from src.generate import generate
+    # Define tokenizer
    tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
                               os.path.join(args_opt.tokenizer_path, 'vocab10.model'))

+    # Tokenize input sentence to ids
    sample = "今天是一个好天气"
    tokenized_token = tokenizer.tokenize(sample)
    start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
    input_ids = np.array(start_sentence).reshape(1, -1)
+    # Call inference
    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
+    # Decode output ids to sentence
    output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
    print('Output is:', output_samples, flush=True)
@@ -40,16 +40,22 @@ def get_input_data(input_ids, eod_id, rank, dis):
    input_ids = input_ids[rank*dis: (rank+1)*dis]
    seq_length = input_ids.shape[1] - 1

+    # Initialize position_ids and attention_mask
    batch_input_ids = input_ids
    batch_position_ids = np.ones((dis, seq_length))
    batch_attention_mask = np.ones((dis, seq_length, seq_length))

+    # Loop through batches
    for bs_i, _ in enumerate(range(len(input_ids))):
+        # Get normal position_ids and attention_mask
        local_ids = input_ids[bs_i]
        batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
        batch_position_ids[bs_i] = np.arange(seq_length)
+        # Find eod_of_document
        eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
        prev_index = 0
        for i in range(eod_index.size):
+            # Reset position_ids and attention_mask considering EOD
            index = eod_index[i]
            batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
            batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
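To make the EOD reset above concrete, here is a small self-contained numpy sketch on a single toy sequence; eod_id = 9 and the token values are illustrative assumptions, and the prev_index update mirrors the loop above.

import numpy as np

ids = np.array([11, 12, 9, 13, 14, 15])          # one sample, seq_length = 6, eod_id = 9 (assumed)
seq_length = len(ids)
position_ids = np.arange(seq_length)              # [0 1 2 3 4 5]
attention_mask = np.tril(np.ones((seq_length, seq_length)))

eod_index = np.where(ids == 9)[0]                 # positions of EOD tokens
prev_index = 0
for index in eod_index:
    # Tokens after the EOD cannot attend to tokens at or before it
    attention_mask[(index + 1):, :(index + 1)] = 0
    # Position ids restart counting after the EOD
    position_ids[(index + 1):] -= (index + 1 - prev_index)
    prev_index = index + 1

print(position_ids)       # [0 1 2 0 1 2]
print(attention_mask)

The printed mask shows two independent causal blocks, one per document, which is exactly what the per-batch loop above builds.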
@@ -76,6 +82,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        dataset_restore: the dataset for training or evaluating
    """
    ds.config.set_seed(1)

+    # Get path for source data files
    home_path = os.path.join(os.getcwd(), data_path)
    files = os.listdir(data_path)
    dis = int(batch_size / device_num)

@@ -89,9 +97,12 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        if not name.endswith(".db")
    ]

+    # Load data files and preprocess
    dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_float = C.TypeCast(mstype.float16)

+    # If eod_reset is enabled, two additional inputs will be generated from input_ids
    if eod_reset:
        map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
        dataset = dataset.batch(batch_size, drop_remainder=drop)
@@ -37,22 +37,30 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
    seq_length = seq_length
    _, valid_length = origin_inputs.shape
    pad_length = seq_length - origin_inputs.shape[-1]
+    # Pad original inputs to seq_length
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
    print("input_ids is ", input_ids)

+    # A single loop generates one token; loop until reaching the target seq_length or generating an eod token
    while valid_length < seq_length:
        inputs = Tensor(input_ids, mstype.int32)
+        # Call a single inference
        probs, p_args = model.predict(inputs)
+        # Get the topk values and indices for the current position from the padded outputs
        probs = probs.asnumpy()[valid_length-1, :]
        p_args = p_args.asnumpy()[valid_length-1, :]
+        # Randomly select a token as the final output for this round
        p = probs
        p = p / sum(p)
        target_index = np.random.choice(len(p), p=p)
+        # Stop judgment
        if p_args[target_index] == end_token or valid_length == seq_length-1:
            outputs = input_ids
            break
+        # Modify input_ids with the newly generated token
        input_ids[0][valid_length] = p_args[target_index]
        valid_length += 1
+    # Return valid outputs out of the padded outputs
    length = np.sum(outputs != 0)
    outputs = outputs[0][:length]
    return outputs
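The loop above can be read independently of MindSpore with a toy stand-in for model.predict; everything below (fake_predict, the vocabulary size, the end_token value) is an illustrative assumption rather than the project's API.

import numpy as np

def fake_predict(input_ids, k=5, vocab_size=50):
    # Stand-in for model.predict: return top-k probabilities and token ids
    # for every position of the padded sequence (illustrative only).
    rng = np.random.default_rng(0)
    logits = rng.random((input_ids.shape[1], vocab_size))
    p_args = np.argsort(-logits, axis=-1)[:, :k]
    probs = np.take_along_axis(logits, p_args, axis=-1)
    return probs, p_args

def toy_generate(origin_inputs, seq_length, end_token=9):
    _, valid_length = origin_inputs.shape
    pad_length = seq_length - valid_length
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant')
    outputs = input_ids
    while valid_length < seq_length:
        probs, p_args = fake_predict(input_ids)
        # Take the top-k distribution at the last valid position and renormalize it
        p = probs[valid_length - 1, :]
        p = p / p.sum()
        target_index = np.random.choice(len(p), p=p)
        if p_args[valid_length - 1, target_index] == end_token or valid_length == seq_length - 1:
            outputs = input_ids
            break
        input_ids[0][valid_length] = p_args[valid_length - 1, target_index]
        valid_length += 1
    length = np.sum(outputs != 0)
    return outputs[0][:length]

print(toy_generate(np.array([[11, 12, 13]]), seq_length=8))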
@@ -203,7 +203,9 @@ class Output(nn.Cell):
        super(Output, self).__init__()
        input_size = config.embedding_size
        output_size = config.embedding_size * config.expand_ratio
+        # Project to expand_ratio*embedding_size
        self.mapping = Mapping_output(config, input_size, output_size)
+        # Project back to embedding_size
        self.projection = Mapping(config, output_size, input_size, scale)
        self.activation = nn.GELU()
        self.activation.gelu.shard(((config.dp, 1, config.mp),))

@@ -212,8 +214,10 @@ class Output(nn.Cell):
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, x):
+        # [bs, seq_length, expand_ratio*embedding_size]
        hidden = self.activation(self.mapping(x))
        output = self.projection(hidden)
+        # [bs, seq_length, embedding_size]
        output = self.dropout(output)
        return output
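The shape comments above describe the usual transformer feed-forward flow: expand to expand_ratio*embedding_size, apply GELU, project back. A plain numpy sketch with made-up sizes (the tanh GELU approximation is an assumption of this sketch):

import numpy as np

bs, seq_length, embedding_size, expand_ratio = 2, 4, 8, 4
x = np.random.randn(bs, seq_length, embedding_size)
w_map = np.random.randn(embedding_size, embedding_size * expand_ratio)
w_proj = np.random.randn(embedding_size * expand_ratio, embedding_size)

def gelu(t):
    # tanh approximation of GELU
    return 0.5 * t * (1 + np.tanh(np.sqrt(2 / np.pi) * (t + 0.044715 * t ** 3)))

hidden = gelu(x @ w_map)       # [bs, seq_length, expand_ratio*embedding_size]
output = hidden @ w_proj       # [bs, seq_length, embedding_size]
print(hidden.shape, output.shape)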
@@ -235,6 +239,7 @@ class AttentionMask(nn.Cell):
            ((config.dp, 1, 1), (config.dp, 1, 1)))  # yzz: use 64, 1, 1?
        self.expand_dim = P.ExpandDims().shard(((1, 1),))
        ones = np.ones(shape=(config.seq_length, config.seq_length))
+        # Default lower triangle mask matrix
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))

@@ -245,12 +250,14 @@ class AttentionMask(nn.Cell):
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
+        # Mask the padded inputs
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
+        # [bs, seq_length, seq_length]
        attention_mask = self.multiply(
-            attention_mask, lower_traiangle) #bs seq_length seq_length
+            attention_mask, lower_traiangle)
        return attention_mask
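A numpy sketch of what this construct computes, assuming a single padded sample whose input_mask marks real tokens with 1: the padding mask is an outer product of the mask with itself, and it is then combined with the causal lower-triangle mask.

import numpy as np

input_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.float32)   # bs=1, seq_length=5, last two positions padded
seq_length = input_mask.shape[1]

# Outer product masks padded rows and columns: [bs, seq_length, seq_length]
mask_left = input_mask[:, :, None]
mask_right = input_mask[:, None, :]
attention_mask = mask_left * mask_right

# Combine with the default lower-triangle (causal) mask
lower_triangle = np.tril(np.ones((seq_length, seq_length)))[None, :, :]
attention_mask = attention_mask * lower_triangle
print(attention_mask[0])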
@@ -305,7 +312,9 @@ class Attention(nn.Cell):
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
+        # Attention mask matrix
        self.get_attention_mask = AttentionMask(config)
+        # Output layer
        self.projection = Mapping(config, config.embedding_size,
                                  config.embedding_size, scale)
        self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),))

@@ -313,6 +322,7 @@ class Attention(nn.Cell):
            ((config.dp, config.mp, 1, 1),))
        self.reshape = P.Reshape()
        self.n_head = config.num_heads
+        # Embedding size per head
        self.size_per_head = config.embedding_size // self.n_head
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)

@@ -329,6 +339,7 @@ class Attention(nn.Cell):
            ((config.dp, 1, 1, 1), (1,)))
        self.add = P.TensorAdd().shard(
            ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1)))
+        # Normalization factor for attention, sqrt(dk) as widely used
        if self.scale:
            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
        if layer_idx is not None:

@@ -347,16 +358,19 @@ class Attention(nn.Cell):
        self.softmax.softmax.shard(((config.dp, config.mp, 1),))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

+        # Query
        self.dense1 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
+        # Key
        self.dense2 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
+        # Value
        self.dense3 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)

@@ -380,18 +394,22 @@ class Attention(nn.Cell):

        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
+        # Self attention: query, key, value are derived from the same inputs
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
+        # [bs, num_heads, seq_length, size_per_head]
        query = self.transpose(
            F.reshape(
                query,
                (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
+        # [bs, num_heads, size_per_head, seq_length]
        key = self.transpose(
            F.reshape(
                key, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
+        # [bs, num_heads, seq_length, size_per_head]
        value = self.transpose(
            F.reshape(
                value,

@@ -403,8 +421,11 @@ class Attention(nn.Cell):
            key = self.concat_k((past_key, key))
            value = self.concat_v(past_value, value)
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
+        # Self-attention considering the attention mask
        attention = self._attn(query, key, value, attention_mask)
+        # [bs, seq_length, embedding_size]
        attention_merge = self.merge_heads(attention)
+        # Output
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present
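The shape comments above can be checked with a small numpy sketch using toy sizes (all values assumed for illustration):

import numpy as np

bs, seq_length, n_head, size_per_head = 2, 6, 4, 8
embedding_size = n_head * size_per_head

query = np.random.randn(bs * seq_length, embedding_size)
# [bs, num_heads, seq_length, size_per_head]
query = query.reshape(bs, seq_length, n_head, size_per_head).transpose(0, 2, 1, 3)
key = np.random.randn(bs * seq_length, embedding_size)
# [bs, num_heads, size_per_head, seq_length]
key = key.reshape(bs, seq_length, n_head, size_per_head).transpose(0, 2, 3, 1)

score = query @ key            # [bs, num_heads, seq_length, seq_length]
print(query.shape, key.shape, score.shape)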
@@ -454,11 +475,14 @@ class Attention(nn.Cell):
        Returns:
            weighted_values: Tensor, the weighted sum scores
        """
+        # Normalize query and key before MatMul, default off
        if not self.scale:
            query = query / F.cast(self.coeff, F.dtype(query))
            key = key / F.cast(self.coeff, F.dtype(key))

+        # Attention score [bs, num_heads, seq_length, seq_length]
        score = self.batch_matmul(query, key)
+        # Normalize after the query and key MatMul, default on
        if self.scale:
            score = self.real_div(
                score,

@@ -466,6 +490,7 @@ class Attention(nn.Cell):

        ori_dtype = P.DType()(score)
        score = P.Cast()(score, mstype.float32)
+        # Subtract 10000 at masked positions to exclude them from the softmax
        multiplu_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))

@@ -474,13 +499,15 @@ class Attention(nn.Cell):
        attention_scores = self.add(adder, score)

        shape = F.shape(attention_scores)
+        # Attention probabilities
        attention_probs = self.softmax(
            F.reshape(attention_scores,
-                      (shape[0], -1, shape[-1]))) # yzz modify
+                      (shape[0], -1, shape[-1])))
        attention_probs = P.Cast()(attention_probs, ori_dtype)
        attention_probs = F.reshape(attention_probs, shape)

        attention_probs = self.prob_dropout(attention_probs)
+        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
        weighted_values = self.batch_matmul(attention_probs, value)
        return weighted_values
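For reference, a standalone numpy version of the core of _attn under toy shapes: scale the scores by sqrt(size_per_head), push masked positions to a large negative value, softmax, then take the weighted sum. Shapes and values are assumptions for illustration.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

bs, n_head, seq_length, size_per_head = 1, 2, 4, 8
query = np.random.randn(bs, n_head, seq_length, size_per_head)
key = np.random.randn(bs, n_head, size_per_head, seq_length)
value = np.random.randn(bs, n_head, seq_length, size_per_head)
attention_mask = np.tril(np.ones((seq_length, seq_length)))[None, None, :, :]

score = query @ key / np.sqrt(size_per_head)        # [bs, num_heads, seq_length, seq_length]
adder = (1.0 - attention_mask) * -10000.0            # large negative where masked
attention_probs = softmax(score + adder)
weighted_values = attention_probs @ value            # [bs, num_heads, seq_length, size_per_head]
print(weighted_values.shape)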
@@ -517,11 +544,13 @@ class Block(nn.Cell):
        self.attention = Attention(config, scale, layer_idx)
        self.layernorm2.gamma.parallel_optimizer = False
        self.layernorm2.beta.parallel_optimizer = False
+        # Feed Forward Network, FFN
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.last_add = P.TensorAdd().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1)))
+        # The last activation of this layer will be saved for recompute in the backward process
        self.last_add.recompute(False)
        self.dtype = config.compute_dtype

@@ -529,12 +558,15 @@ class Block(nn.Cell):
        r"""
        The forward process of the block.
        """
+        # [bs, seq_length, embedding_size]
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
        attention, layer_present = self.attention(input_x, input_mask,
                                                  layer_past)
+        # For post-layernorm the residual path takes the output of self-attention and the output of layernorm
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
+        # For pre-layernorm the residual path takes the output of self-attention and the input of this layer
        else:
            x = self.add(x, attention)
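The two comments above distinguish the post-layernorm and pre-layernorm residual variants; a minimal numpy sketch of the difference, with a dummy sublayer standing in for attention or the FFN:

import numpy as np

def layernorm(t, eps=1e-5):
    mean = t.mean(-1, keepdims=True)
    var = t.var(-1, keepdims=True)
    return (t - mean) / np.sqrt(var + eps)

def sublayer(t):
    return 0.1 * t                      # stand-in for attention or FFN

x = np.random.randn(2, 4, 8)

# Pre-layernorm: the residual adds the block input x
out_pre = x + sublayer(layernorm(x))

# Post-layernorm: the residual adds the layernorm output
normed = layernorm(x)
out_post = normed + sublayer(normed)
print(out_pre.shape, out_post.shape)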
@@ -556,6 +588,7 @@ class QueryLayerAttention(Attention):
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
+        # For query_layer_attention, the query is derived from the added query_embedding parameter, while the key and value are derived from the outputs of the previous layer
        query = self.dense1(query_hidden_state)
        key = self.dense2(x)
        value = self.dense3(x)

@@ -611,7 +644,8 @@ class QueryLayer(nn.Cell):

    def construct(self, x, query_hidden_state, input_mask, layer_past=None):
        r"""
-        Query Layer.
+        Query Layer shares a similar structure with the normal layer block,
+        except that it is not a traditional self-attention.
        """
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
@@ -650,6 +684,7 @@ class PanguAlpha_Model(nn.Cell):
    def __init__(self, config):
        super(PanguAlpha_Model, self).__init__()
        self.get_attention_mask = AttentionMask(config)
+        # Word embedding
        self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
        if config.load_ckpt_path:
            # Loading the embedding table from the ckpt path:

@@ -662,6 +697,7 @@ class PanguAlpha_Model(nn.Cell):
        else:
            position_table_param = TruncatedNormal(0.02)

+        # Position embedding
        self.position_embedding = nn.Embedding(
            config.seq_length,
            config.embedding_size,

@@ -671,11 +707,13 @@ class PanguAlpha_Model(nn.Cell):
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.blocks = nn.CellList()
+        # Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
        fusion_group_num = 4
        fusion_group_size = config.num_layers // fusion_group_num
        fusion_group_size = max(fusion_group_size, 1)

        num_layers = config.num_layers
+        # If top_query_attention is enabled, replace the last normal self-attention layer with the top_query_attention layer
        if config.use_top_query_attention:
            num_layers -= 1
        self.num_layers = num_layers

@@ -683,7 +721,10 @@ class PanguAlpha_Model(nn.Cell):

        for i in range(num_layers):
            per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
+            # Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
+            # in other words, in the backward process each block will be almost totally recomputed.
            per_block.recompute()
+            # Dropout will not be recomputed to ensure consistency between the forward pass and the corresponding backward pass.
            per_block.attention.dropout.dropout_gen_mask.recompute(False)
            per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
            per_block.output.dropout.dropout_gen_mask.recompute(False)
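As a quick illustration of the fusion-group assignment used above, int(i / fusion_group_size) + 2, assuming num_layers = 12 for the sake of the example:

# Illustrative only: how layers map to communication fusion groups,
# assuming num_layers = 12 and fusion_group_num = 4 as in the code above.
num_layers = 12
fusion_group_num = 4
fusion_group_size = max(num_layers // fusion_group_num, 1)
for i in range(num_layers):
    print(i, int(i / fusion_group_size) + 2)   # groups 2, 3, 4, 5; group 1 is the word embedding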
@@ -709,6 +750,9 @@ class PanguAlpha_Model(nn.Cell):
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        self.eod_reset = config.eod_reset
+        # If top_query_attention is enabled, the input_position representing the position ids will be used as the index
+        # into a query embedding table to obtain the top query hidden states. Together with the previous outputs of the
+        # normal self-attention layers, a new attention layer will be attached to the output of the model.
        if config.use_top_query_attention:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:

@@ -745,19 +789,24 @@ class PanguAlpha_Model(nn.Cell):
        if not self.use_past:
            layer_past = self.past

+        # Word embedding
        input_embedding, embedding_table = self.word_embedding(input_ids)
+        # If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids,
+        # and the corresponding input_position and attention_mask will be derived from it.
        if not self.eod_reset:
            batch_size, seq_length = F.shape(input_ids)
            input_position = F.tuple_to_array(F.make_range(seq_length))
            input_position = P.Tile()(input_position, (batch_size, 1))
            attention_mask = self.get_attention_mask(input_mask)
        position_embedding = self.position_embedding(input_position)
+        # Input features [bs, seq_length, embedding_size]
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.expand_dims(attention_mask, 1)

        present_layer = ()
+        # Loop through each self-attention layer
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask, layer_past)

@@ -766,12 +815,12 @@ class PanguAlpha_Model(nn.Cell):
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)

+        # Top query attention layer
        if self.use_top_query_attention:
            top_query_hidden_states = self.top_query_embedding(input_position)
            output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
                                                         attention_mask, layer_past)
            present_layer = present_layer + (present,)

        return output_state, present_layer, embedding_table
@@ -799,6 +848,7 @@ class PanguAlpha_Head(nn.Cell):

    def construct(self, state, embedding_table):
        state = P.Reshape()(state, (-1, self.embedding_size))
+        # output logits over vocabulary [bs*seq_length, vocab_size]
        logits = self.matmul(state, self.cast(embedding_table, self.dtype))
        return logits

@@ -817,7 +867,9 @@ class PanguAlpha(nn.Cell):
    """
    def __init__(self, config):
        super(PanguAlpha, self).__init__()
+        # Network backbone of PanguAlpha
        self.backbone = PanguAlpha_Model(config)
+        # Network head to get logits over vocabulary
        self.head = PanguAlpha_Head(config)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
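The head reuses the word embedding table to produce logits, as the comment above notes. A numpy sketch of that shape flow, with toy sizes assumed for illustration:

import numpy as np

bs, seq_length, embedding_size, vocab_size = 2, 4, 8, 16
state = np.random.randn(bs, seq_length, embedding_size)
embedding_table = np.random.randn(vocab_size, embedding_size)   # shared with the word embedding

state = state.reshape(-1, embedding_size)           # [bs*seq_length, embedding_size]
logits = state @ embedding_table.T                   # [bs*seq_length, vocab_size]
print(logits.shape)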
@@ -844,6 +896,7 @@ class CrossEntropyLoss(nn.Cell):
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum().shard(((config.dp, config.mp),))
        self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ()))
+        # on/off values for onehot; for label smoothing, modify the off_value
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.vocab_size = config.vocab_size

@@ -868,7 +921,9 @@ class CrossEntropyLoss(nn.Cell):
        r"""
        Compute loss using logits, label and input mask
        """
+        # [bs*seq_length, vocab_size]
        logits = F.cast(logits, mstype.float32)
+        # LogSoftmax over the last dimension of the logits
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)

@@ -876,12 +931,17 @@ class CrossEntropyLoss(nn.Cell):
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))

+        # Flatten label to [bs*seq_length]
        label = P.Reshape()(label, (-1,))
+        # Get onehot label [bs*seq_length, vocab_size]
        one_hot_label = self.onehot(label, self.vocab_size, self.on_value,
                                    self.off_value)
+        # Cross-entropy loss
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
+        # input_mask indicates whether there are padded inputs; padded positions are not counted in the loss
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))
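A compact numpy sketch of the masked cross-entropy computed above (the exact epsilon and the denominator used for averaging are assumptions of this sketch):

import numpy as np

def masked_cross_entropy(logits, labels, input_mask, eps=1e-24):
    # logits: [bs*seq_length, vocab_size], labels/input_mask: [bs*seq_length]
    logit_max = logits.max(axis=-1, keepdims=True)
    logit_exp = np.exp(logits - logit_max)
    softmax = logit_exp / logit_exp.sum(axis=-1, keepdims=True)
    log_softmax = np.log(softmax + eps)
    one_hot = np.eye(logits.shape[-1])[labels]
    loss_per_token = -(log_softmax * one_hot).sum(axis=-1)
    # Padded positions (mask == 0) do not contribute to the loss
    return (loss_per_token * input_mask).sum() / max(input_mask.sum(), 1.0)

logits = np.random.randn(6, 10)
labels = np.array([1, 4, 2, 0, 0, 0])
input_mask = np.array([1, 1, 1, 0, 0, 0], dtype=np.float32)
print(masked_cross_entropy(logits, labels, input_mask))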
@@ -909,6 +969,7 @@ class PanguAlphaWithLoss(nn.Cell):
        super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
+        # id for end_of_sentence, 6 in the vocabulary
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))

@@ -922,6 +983,10 @@ class PanguAlphaWithLoss(nn.Cell):
        r"""
        PanguAlphaWithLoss
        """
+        # input_ids [bs, seq_length+1]
+        # input_position [bs, seq_length], only available when eod_reset is enabled
+        # attention_mask [bs, seq_length, seq_length], only available when eod_reset is enabled
+        # Get input tokens [bs, seq_length]
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))

        if self.eod_reset:

@@ -929,12 +994,14 @@ class PanguAlphaWithLoss(nn.Cell):
            attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
                                             (self.batch_size, self.len, self.len),
                                             (1, 1, 1))
+        # Check whether there is padding in the inputs
        input_mask = F.cast(self.not_equal(tokens, self.eos_token),
                            mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
+        # Get the labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
                            (1, 1))
+        # Loss
        output = self.loss(logits, labels, input_mask)
        return output
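A small numpy sketch of the slicing above: inputs and labels are the same sequence shifted by one token, and positions equal to the eos token are masked out of the loss. The sample values are illustrative.

import numpy as np

eos_token = 6   # id for end_of_sentence, as noted above
# input_ids carries seq_length + 1 tokens per sample
input_ids = np.array([[11, 12, 13, 6, 6]])

tokens = input_ids[:, :-1]       # model inputs  [bs, seq_length]
labels = input_ids[:, 1:]        # targets shifted by one position
input_mask = (tokens != eos_token).astype(np.float32)   # 0 where padded with eos
print(tokens, labels, input_mask, sep="\n")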
@@ -50,13 +50,17 @@ class PANGUALPHAConfig:
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.num_heads = num_heads
+        # The expansion ratio of the feature size in the FFN
        self.expand_ratio = expand_ratio
+        # Use post-layernorm or pre-layernorm, default: pre-layernorm
        self.post_layernorm_residual = post_layernorm_residual
        self.dropout_rate = dropout_rate
        self.compute_dtype = compute_dtype
+        # Whether to use incremental inference
        self.use_past = use_past
        self.dp = data_parallel_num
        self.mp = model_parallel_num
+        # Whether to use the self-implemented layernorm
        self.self_layernorm = self_layernorm
        self.stage_num = stage_num
        self.micro_size = micro_size
@@ -45,6 +45,7 @@ def _clip_grad(clip_type, clip_value, grad):
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
+    # 0 for clip_by_value and 1 for clip_by_norm
    if clip_type == 0:
        new_grad = C.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value,)), dt),

@@ -107,12 +108,14 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
+        # Forward process
        loss = self.network(input_ids, input_position, attention_mask)
        scaling_sens = self.scale_sense

        # alloc status and clear should be right before gradoperation
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        # Backward process using loss scale
        grads = self.grad(self.network,
                          weights)(input_ids,
                                   input_position, attention_mask,

@@ -129,8 +132,11 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        grads = self.hyper_map(
            F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
            grads)
+        # Check whether overflow occurred
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
+        # If overflow, skip the weights update;
+        # if not, update the weights
        if overflow:
            succ = False
        else:
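The overflow handling above follows the usual dynamic loss-scaling pattern. The sketch below is a generic illustration of that pattern, not MindSpore's exact update rule; the scale_factor and scale_window values are assumptions.

# Generic dynamic loss scaling: on overflow the update is skipped and the
# scale is reduced; after enough clean steps the scale is raised again.
loss_scale, scale_factor, scale_window, good_steps = 2.0 ** 32, 2.0, 1000, 0

def step(overflow):
    global loss_scale, good_steps
    if overflow:
        loss_scale = max(loss_scale / scale_factor, 1.0)
        good_steps = 0
        return False          # skip the weights update
    good_steps += 1
    if good_steps % scale_window == 0:
        loss_scale *= scale_factor
    return True               # apply the weights update

print(step(True), step(False), loss_scale)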
@@ -70,7 +70,9 @@ class GlobalNorm(nn.Cell):
        self.values.append(Tensor([self.group_size*1.0], mstype.float32))
        self.values = tuple(self.values)
    def construct(self, grads):
+        # Square sum of gradients for current rank
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
+        # Global square sum of gradients
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms
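A single-process numpy sketch of the global norm: each rank sums the squares of its local gradients, the sums are all-reduced, and the square root is taken. Here the AllReduce is simulated by a plain Python sum over two hypothetical ranks.

import numpy as np

grads_per_rank = [
    [np.array([1.0, 2.0]), np.array([2.0])],   # rank 0
    [np.array([3.0]), np.array([1.0, 1.0])],   # rank 1
]
square_sums = [sum(float(np.sum(g * g)) for g in grads) for grads in grads_per_rank]
global_norm = np.sqrt(sum(square_sums))        # sqrt(1 + 4 + 4 + 9 + 1 + 1) = sqrt(20)
print(global_norm)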
@@ -77,10 +77,12 @@ def run_train(args_opt):
    The main training process.
    """
    device_id = int(os.getenv('DEVICE_ID'))
+    # Set execution mode
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
+    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()

@@ -102,6 +104,8 @@ def run_train(args_opt):
    else:
        rank = 0
        device_num = 1

+    # Set model property
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    batch_size = args_opt.per_batch_size * device_num

@@ -124,18 +128,23 @@ def run_train(args_opt):
                      eod_reset=bool(args_opt.eod_reset),
                      word_emb_dp=True)
    print("===config is: ", config, flush=True)

+    # Define network
    pangu_alpha = PanguAlpha(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

    print("=====args_opt is: ", args_opt, flush=True)

+    # Warm-up and cosine decay learning rate
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=200000,
                      lr_scale=1)

+    # Set weight decay coefficient: zero for bias and layernorm, 1e-1 for the rest
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    params = pangu_alpha.trainable_params()
    decay_params = list(filter(decay_filter, params))
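The LearningRate cell above combines a warm-up phase with cosine decay from start_lr to end_lr. The function below sketches that common schedule shape; it is an assumption for illustration and may differ in detail from the project's implementation.

import math

def lr_at(step, start_lr, end_lr, warmup_steps, decay_steps):
    # Linear warm-up followed by cosine decay toward end_lr
    if step < warmup_steps:
        return start_lr * step / max(warmup_steps, 1)
    progress = min((step - warmup_steps) / max(decay_steps - warmup_steps, 1), 1.0)
    return end_lr + 0.5 * (start_lr - end_lr) * (1 + math.cos(math.pi * progress))

for step in (0, 1000, 10000, 100000, 200000):
    print(step, lr_at(step, start_lr=5e-5, end_lr=1e-6, warmup_steps=2000, decay_steps=200000))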
@@ -153,8 +162,10 @@ def run_train(args_opt):
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
+    # Initial scaling sens
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size
+    # Dataset loading mindrecord files
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
                        data_start_index=0, eod_reset=config.eod_reset,
                        eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)