!19163 pangu alpha incremental inference

Merge pull request !19163 from wangjun/master_0630
This commit is contained in:
i-robot 2021-07-01 07:47:08 +00:00 committed by Gitee
commit d642e1de53
3 changed files with 460 additions and 99 deletions
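Editor note: this change set enables incremental (state-reuse) inference for PanguAlpha. The prompt is run once through a graph that takes the full padded (bs, seq_length) input and fills per-layer key/value caches; every later step feeds only the newly generated token through a second graph with (bs, 1) input that reuses those caches. Below is a minimal, framework-free sketch of that call pattern; first_graph, second_graph and their arguments are illustrative stand-ins, not the repository API.

import numpy as np

def incremental_decode(first_graph, second_graph, prompt_ids, seq_length, max_new_tokens, end_token=9):
    # First pass: the full padded sequence (re)initializes the key/value caches
    valid_length = len(prompt_ids)
    input_ids = np.zeros((1, seq_length), dtype=np.int32)
    input_ids[0, :valid_length] = prompt_ids
    outputs = list(prompt_ids)
    logits = first_graph(input_ids, batch_valid_length=valid_length - 1)
    for _ in range(max_new_tokens):
        next_token = int(np.argmax(logits[0]))   # greedy here; the scripts sample with top-p / top-k
        if next_token == end_token or valid_length >= seq_length:
            break
        outputs.append(next_token)
        valid_length += 1
        # Later passes: a single-token input reuses the cached key/value states
        logits = second_graph(np.array([[next_token]], dtype=np.int32),
                              batch_valid_length=valid_length - 1)
    return np.array(outputs)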

View File

@ -111,7 +111,16 @@ def run_predict(args_opt):
# Compile network and obtain tensor layout for loading ckpt
inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
predict_layout = model_predict.infer_predict_layout(inputs_np)
current_index = Tensor(np.array([0]), mstype.int32)
if config.use_past:
batch_valid_length = Tensor(np.array([0]), mstype.int32)
init_true = Tensor([True], mstype.bool_)
inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
_ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
else:
predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
print("======start load_distributed checkpoint", flush=True)
# For 2.6B and 13B models, the number of ckpt files is 512.
ckpt_name = 'filerted'
@ -123,7 +132,7 @@ def run_predict(args_opt):
print("================load param ok=================", flush=True)
from src.tokenization_jieba import JIEBATokenizer
from src.generate import generate
from src.generate import generate, generate_increment
# Define tokenizer
tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
os.path.join(args_opt.tokenizer_path, 'vocab10.model'))
@ -134,7 +143,8 @@ def run_predict(args_opt):
start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
input_ids = np.array(start_sentence).reshape(1, -1)
# Call inference
output_ids = generate(model_predict, input_ids, opt)
generate_func = generate_increment if config.use_past else generate
output_ids = generate_func(model_predict, input_ids, opt)
# Decode output ids to sentence
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
print('Output is:', output_samples, flush=True)

View File

@ -24,7 +24,7 @@ from mindspore.ops import operations as P
def generate(model, origin_inputs, config):
"""
TopK for text generation
Text generation
Inputs:
model: the model for inferencing
@ -59,7 +59,7 @@ def generate(model, origin_inputs, config):
while valid_length < target_length:
inputs = Tensor(input_ids, mstype.int32)
# Indicate the exact token position
current_index = valid_length - 1 if valid_lenght - 1 > 0 else 0
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
current_index = Tensor([current_index], mstype.int32)
# Call a single inference
log_probs = model.predict(inputs, current_index)
@ -103,6 +103,10 @@ def generate(model, origin_inputs, config):
if p_args[target_index] == end_token or valid_length == target_length-1:
outputs = input_ids
break
# update frequency list
target = p_args[target_index]
frequency_list[0][target] = frequency_list[0][target] + 1
# Modify input_ids with newly generated token
input_ids[0][valid_length] = p_args[target_index]
valid_length += 1
@ -110,3 +114,117 @@ def generate(model, origin_inputs, config):
length = np.sum(outputs != 0)
outputs = outputs[0][:length]
return outputs
def generate_increment(model, origin_inputs, config):
"""
Text generation for incremental inference
Inputs:
model: the model for inferencing
origin_inputs: the original inputs based on which the model will continue writing
config: inference configurations
Returns:
outputs: the ids for the generated text
"""
# Get configurations for inference
frequency_penalty = config.frequency_penalty
presence_penalty = config.presence_penalty
top_p = config.top_p
top_k_num = config.top_k_num
max_generate_length = config.max_generate_length
seq_length = config.seq_length
end_token = config.end_token
_, valid_length = origin_inputs.shape
# Init outputs with original inputs
outputs = [origin_inputs[0][i] for i in range(valid_length)]
# If target length exceeds seq_length, use seq_length instead
target_length = valid_length + max_generate_length
target_length = seq_length if target_length > seq_length else target_length
# A list of the frequency of each token
frequency_list = np.array([[0 for _ in range(config.vocab_size)]])
pad_length = seq_length - origin_inputs.shape[-1]
# Pad original inputs to seq_length
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
print("input_ids is ", input_ids)
# Indicate the exact token position
current_index = valid_length - 1 if valid_length - 1 > 0 else 0
batch_valid_length = Tensor(np.array([current_index]), mstype.int32)
current_index = Tensor(np.array([current_index]), mstype.int32)
# For the first graph, init (init_reset) should be False so that the saved key/value states are cleared
init_true = Tensor([True], mstype.bool_)
init_false = Tensor([False], mstype.bool_)
init = init_false
# Switch to the first graph
model.predict_network.add_flags_recursive(is_first_iteration=True)
# Call a single inference with input size of (bs, seq_length)
logits = model.predict(Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length)
# Switch to the second graph and set init (init_reset) to True so the saved states are reused
init = init_true
model.predict_network.add_flags_recursive(is_first_iteration=False)
# Each loop iteration generates one token; loop until reaching target_length or generating the end token
while valid_length < target_length:
# Reshape the output logits
logits = logits.asnumpy()
log_probs = logits.reshape(1, config.vocab_size)
# Get the revised log_probs considering frequency and presence penalties to reduce repetition in generated results
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
# Convert the log_probs to probability
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
# If top_p is less than 1.0, use top_p sampling
if top_p < 1.0:
# Only consider the 5000 largest logits to reduce computation
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
cumsum_logits = P.CumSum()(sorted_logits, 1)
cumsum_logits = cumsum_logits.asnumpy()[0]
index = index.asnumpy()[0]
sorted_logits = sorted_logits.asnumpy()[0]
top_p_num = sum(cumsum_logits > top_p)
# In case the distribution is so smooth that even the cumulative sum of the 5000 largest probabilities does not exceed top_p
if top_p_num == 0:
top_p_num = 5000
# Get the corresponding probs and indices
probs = sorted_logits[:top_p_num]
p_args = index[:top_p_num]
p = probs / sum(probs)
# if top_p is set to 1.0, use top_k sampling
else:
# Get the corresponding probs and indices
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
probs = probs.asnumpy()[0]
p_args = p_args.asnumpy()[0]
# Avoid rounding error
if sum(probs) == 0:
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
p = probs / sum(probs)
# Randomly select a token as the final output for this round
target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == target_length-1:
break
# Update frequency list
target = p_args[target_index]
frequency_list[0][target] = frequency_list[0][target] + 1
valid_length += 1
batch_valid_length = Tensor(np.array([valid_length - 1]), mstype.int32)
current_index = Tensor(np.array([0]), mstype.int32)
input_id = Tensor([[target]], mstype.int32)
# Update outputs with current generated token
outputs.append(int(target))
# Call a single inference with input size of (bs, 1)
logits = model.predict(input_id, current_index, init, batch_valid_length)
# Return valid outputs out of padded outputs
return np.array(outputs)
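Editor note: generate and generate_increment share the same penalty-plus-sampling step. The following standalone numpy rendering mirrors that logic for a single step; the helper name and default arguments are illustrative, and it is a sketch rather than a drop-in replacement for the scripts above.

import numpy as np

def sample_next_token(log_probs, frequency_list, top_p=0.9, top_k_num=5,
                      frequency_penalty=1.5, presence_penalty=0.3):
    # log_probs and frequency_list are 1-D arrays over the vocabulary
    revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
    probs = np.power(10.0, revised)                       # same base-10 convention as above
    if top_p < 1.0:
        order = np.argsort(-probs)[:5000]                 # only the 5000 largest candidates
        sorted_probs = probs[order]
        cumsum = np.cumsum(sorted_probs)
        top_p_num = int(np.sum(cumsum > top_p)) or 5000   # fall back to all 5000 candidates
        p_args, p = order[:top_p_num], sorted_probs[:top_p_num]
    else:
        p_args = np.argsort(-probs)[:top_k_num]
        p = probs[p_args]
        if p.sum() == 0:                                  # avoid rounding error
            p = np.ones(top_k_num) / top_k_num
    p = p / p.sum()
    return int(np.random.choice(p_args, p=p))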

View File

@ -420,7 +420,28 @@ class Attention(nn.Cell):
self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1)))
self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,)))
def construct(self, x, attention_mask, layer_past=None):
self.is_first_iteration = True
self.dtype = config.compute_dtype
self.use_past = config.use_past
if self.use_past:
# operators used for state reuse
seq_range = np.arange(config.seq_length).reshape(1, 1, -1)
self.range = Tensor(np.tile(seq_range, (config.batch_size, 1, 1)), mstype.int32)
self.seq_length = config.seq_length
self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
self.add = P.TensorAdd().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
self.sub1 = P.Sub().shard(((1,), ()))
self.tile = P.Tile().shard(((1, 1, 1, 1),))
self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
def construct(self, x, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
"""
self-attention
@ -428,7 +449,9 @@ class Attention(nn.Cell):
x: output of previous layer
attention_mask: the attention mask matrix with shape (batch_size, 1,
seq_length, seq_length)
layer_past: the previous feature map
key_past: previous saved key state
value_past: previous saved value state
batch_valid_length: the valid input seq_length without padding
Returns:
output: Tensor, the output logit of this layer
@ -458,12 +481,44 @@ class Attention(nn.Cell):
value,
(-1, original_shape[1], self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# key and value for current token(s)
key_present = key
value_present = value
if self.use_past:
past_value = layer_past[1]
past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
key = self.concat_k((past_key, key))
value = self.concat_v(past_value, value)
layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
# The first graph with the input size of (bs, seq_length)
if self.is_first_iteration:
# Get the valid input length without padding
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
# Zero out the key and value entries corresponding to the padding positions
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
# The second graph with the input size of (bs, 1)
# the shape of query is (bs, num_heads, 1, size_per_head)
# the shape of key is (bs, num_heads, size_per_head, 1)
# the shape of value is (bs, num_heads, 1, size_per_head)
else:
# Get the current token position index
valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
(F.shape(x)[0], 1, 1, self.seq_length),
(1, 1, 1, 1)),
0), mstype.float32), (1, 2, 3))
valid_length = F.reshape(valid_length, (-1, 1, 1))
valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
# Expand the key and value to seq_length, non-zero only at the current position index
current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
self.expand_dims(valid_length_vector, 2))
current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
self.expand_dims(valid_length_vector, 3))
# Concat the previous saved state and current state
key = self.add(key_past, current_key)
value = self.add(value_past, current_value)
# Update key_present and value_present for state update
key_present = key
value_present = value
attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
layer_present = (key_present, value_present)
# Self-attention considering attention mask
attention = self._attn(query, key, value, attention_mask)
# [bs, seq_length, embedding_size]
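Editor note: to keep static shapes for the parallel setup, the incremental branch never concatenates. The key/value of the current token are tiled to seq_length, masked with a one-hot position vector, and added onto the saved key_past/value_past buffers. A small numpy sketch of that update for the key cache follows, with shapes following key_shape = (bs, num_heads, size_per_head, seq_length); the value cache is handled analogously with the mask expanded on the other axis. This is an illustration of the mechanism, not code from the model.

import numpy as np

bs, num_heads, size_per_head, seq_length = 1, 2, 4, 8
seq_range = np.arange(seq_length).reshape(1, 1, -1)              # plays the role of self.range

# First graph: zero out the key columns that belong to padding
batch_valid_length = np.array([3]).reshape(-1, 1, 1)             # three real tokens in the prompt
valid = (seq_range < batch_valid_length).astype(np.float32)      # (bs, 1, seq_length)
key = np.random.randn(bs, num_heads, size_per_head, seq_length)
key_present = key * np.expand_dims(valid, 2)                     # saved as key_past

# Second graph: scatter the single new key into the fixed-length cache
key_past = key_present
valid_length = (key_past[:, :1, :1, :] != 0).sum(-1).reshape(-1, 1, 1)   # = 3, recovered from the cache
one_hot = (seq_range == valid_length).astype(np.float32)         # 1 only at the next position
new_key = np.random.randn(bs, num_heads, size_per_head, 1)       # key of the current token
current_key = np.tile(new_key, (1, 1, 1, seq_length)) * np.expand_dims(one_hot, 2)
key = key_past + current_key                                     # updated cache keeps shape (..., seq_length)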
@ -523,7 +578,7 @@ class Attention(nn.Cell):
query = query / F.cast(self.coeff, F.dtype(query))
key = key / F.cast(self.coeff, F.dtype(key))
# Attention score [bs, num_heads, seq_length, seq_length]
# Attention score [bs, num_heads, seq_length_q, seq_length_k]
score = self.batch_matmul(query, key)
# Normalize after query and key MatMul, default on
if self.scale:
@ -533,6 +588,22 @@ class Attention(nn.Cell):
ori_dtype = P.DType()(score)
score = P.Cast()(score, mstype.float32)
# For the input size of (bs, 1), namely the second graph, the shape of the attention_mask matrix should be
# (bs, 1, 1, seq_length)
if self.use_past and not self.is_first_iteration:
# Calculate the current total number of tokens
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
(F.shape(query)[0], 1, 1, self.seq_length),
(1, 1, 1, 1)),
0), mstype.float32), (1, 2, 3))
# Get the precise position index
index = self.sub1(F.cast(current_index, mstype.int32), 1)
index = F.reshape(index, (-1, 1, 1))
# Calculate the attention_mask matrix via the position index
attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
attention_mask = self.expand_dims(attention_mask, 2)
# Subtract 10000 at the masked positions to exclude them from the softmax
multiplu_out = self.sub(
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
@ -550,27 +621,28 @@ class Attention(nn.Cell):
attention_probs = F.reshape(attention_probs, shape)
attention_probs = self.prob_dropout(attention_probs)
# Weighted sum output [bs, num_heads, seq_length, size_per_head]
# Weighted sum output [bs, num_heads, seq_length_q, size_per_head]
weighted_values = self.batch_matmul(attention_probs, value)
return weighted_values
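Editor note: in the single-token graph the attention mask can no longer be sliced from the full lower-triangular matrix, so it is rebuilt from the current position: the code counts the filled cache entries, subtracts one to get the newest position, and marks every position up to it as visible, producing a (bs, 1, 1, seq_length) mask. A rough numpy equivalent with illustrative values:

import numpy as np

bs, seq_length = 1, 8
seq_range = np.arange(seq_length).reshape(1, 1, -1)       # (1, 1, seq_length)

current_total = np.array([4])                             # tokens stored in the cache so far
index = (current_total - 1).reshape(-1, 1, 1)             # position of the newest token
mask = (seq_range <= index).astype(np.float32)            # visible positions -> 1
attention_mask = np.expand_dims(mask, 2)                  # (bs, 1, 1, seq_length), as _attn expects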
class Block(nn.Cell):
class Decoder(nn.Cell):
"""
The basic block of PanguAlpha network
The basic decoder structure of PanguAlpha network
Args:
config(PanguAlphaConfig): the config of network
layer_idx: current layer index
Inputs:
x: the output of previous layer(input_ids for the first layer)
attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
layer_past: the previous feature map
init_reset: whether reset the previous state
batch_valid_length: the valid input seq_length without padding
Returns:
output: Tensor, the output logit of this layer
layer_present: Tensor, the feature map of current layer
"""
def __init__(self, config, layer_idx):
super(Block, self).__init__()
super(Decoder, self).__init__()
scale = 1 / math.sqrt(2.0 * config.num_layers)
if config.self_layernorm:
@ -593,16 +665,43 @@ class Block(nn.Cell):
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
# Last activation of this layer will be saved for recompute in backward process
self.dtype = config.compute_dtype
self.use_past = config.use_past
if self.use_past:
# operator used for state reuse
self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
size_per_head = int(config.embedding_size / config.num_heads)
self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length)
self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head)
# parameters saving key and value states
self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
self.tile = P.Tile().shard(((1, 1),))
self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
def construct(self, x, input_mask, layer_past=None):
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
r"""
The forward process of the block.
"""
# [bs, seq_length, embedding_size]
input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype)
# indicate whether to reset the saved states
key_reset = None
value_reset = None
if self.use_past:
# reset states, init_reset True for reuse and False for reset
key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
# add dependency for desired execution order
input_x = F.depend(input_x, key_reset)
input_x = F.depend(input_x, value_reset)
attention, layer_present = self.attention(input_x, input_mask,
layer_past)
self.key_past, self.value_past, batch_valid_length)
# For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
if self.post_layernorm_residual:
x = self.add(input_x, attention)
@ -613,6 +712,22 @@ class Block(nn.Cell):
output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
value_update = None
key_update = None
if self.use_past:
# current key and value
key_present, value_present = layer_present
# update key and value calculated this step
key_update = self.assign(self.key_past, key_present)
value_update = self.assign(self.value_past, value_present)
# add dependency for desired execution order
key_update = F.depend(key_update, key_reset)
value_update = F.depend(value_update, value_reset)
# add dependency for desired execution order
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
if self.post_layernorm_residual:
output = self.add(output_x, mlp_logit)
else:
@ -635,9 +750,14 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell):
self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.use_past = config.use_past
self.is_first_iteration = True
def construct(self, input_ids, table, input_position):
def construct(self, input_ids, table, input_position, valid_index=None):
input_embedding = self.word_embedding(input_ids, table)
if self.use_past and not self.is_first_iteration:
_, seq_length = F.shape(input_ids)
input_position = valid_index.view(1, seq_length)
position_embedding = self.position_embedding(input_position)
hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states)
@ -662,7 +782,7 @@ class QueryLayerAttention(Attention):
r"""
Self-Attention module using input query vector.
"""
def construct(self, x, query_hidden_state, attention_mask, layer_past=None):
def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
original_shape = F.shape(x)
x = F.reshape(x, (-1, original_shape[-1]))
query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
@ -684,12 +804,42 @@ class QueryLayerAttention(Attention):
value,
(-1, original_shape[1], self.n_head, self.size_per_head)),
(0, 2, 1, 3))
key_present = key
value_present = value
if self.use_past:
past_value = layer_past[1]
past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
key = self.concat_k((past_key, key))
value = self.concat_v(past_value, value)
layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
# The first graph with the input size of (bs, seq_length)
if self.is_first_iteration:
# Get the valid input length without padding
valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
# Zero out the key and value entries corresponding to the padding positions
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
# The second graph with the input size of (bs, 1)
# the shape of query is (bs, num_heads, 1, size_per_head)
# the shape of key is (bs, num_heads, size_per_head, 1)
# the shape of value is (bs, num_heads, 1, size_per_head)
else:
# Get the current token position index
valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
(F.shape(x)[0], 1, 1, self.seq_length),
(1, 1, 1, 1)),
0), mstype.float32), (1, 2, 3))
valid_length = F.reshape(valid_length, (-1, 1, 1))
valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
# Expand the key and value to seq_length, non-zero only at the current position index
current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
self.expand_dims(valid_length_vector, 2))
current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
self.expand_dims(valid_length_vector, 3))
# Concat the previous saved state and current state
key = self.add(key_past, current_key)
value = self.add(value_past, current_value)
# Update key_present and value_present for state update
key_present = key
value_present = value
attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
layer_present = (key_present, value_present)
attention = self._attn(query, key, value, attention_mask)
attention_merge = self.merge_heads(attention)
output = self.projection(attention_merge)
@ -715,18 +865,48 @@ class QueryLayer(nn.Cell):
self.post_layernorm_residual = config.post_layernorm_residual
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.dtype = config.compute_dtype
self.use_past = config.use_past
if self.use_past:
# operator used for state reuse
self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
size_per_head = int(config.embedding_size / config.num_heads)
self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length)
self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head)
# parameters saving key and value states
self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
self.tile = P.Tile().shard(((1, 1),))
self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
def construct(self, x, query_hidden_state, input_mask, layer_past=None):
def construct(self, x, query_hidden_state, input_mask, init_reset=True, batch_valid_length=None):
r"""
Query Layer shares a similar structure with a normal decoder block,
except that its attention is not a traditional self-attention.
"""
input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype)
# indicate whether to reset the saved states
key_reset = None
value_reset = None
if self.use_past:
# reset states, init_reset True for reuse and False for reset
key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
# add dependency for desired execution order
input_x = F.depend(input_x, key_reset)
input_x = F.depend(input_x, value_reset)
attention, layer_present = self.attention(input_x,
query_hidden_state,
input_mask,
layer_past)
self.key_past,
self.value_past,
batch_valid_length)
if self.post_layernorm_residual:
x = self.add(input_x, attention)
else:
@ -735,28 +915,46 @@ class QueryLayer(nn.Cell):
output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
value_update = None
key_update = None
if self.use_past:
# current key and value
key_present, value_present = layer_present
# update key and value calculated this step
key_update = self.assign(self.key_past, key_present)
value_update = self.assign(self.value_past, value_present)
# add dependency for desired execution order
key_update = F.depend(key_update, key_reset)
value_update = F.depend(value_update, value_reset)
# add dependency for desired execution order
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
if self.post_layernorm_residual:
output = self.add(output_x, mlp_logit)
else:
output = self.add(x, mlp_logit)
return output, layer_present
class PanguAlpha_Model(nn.Cell):
class Embedding(nn.Cell):
"""
The backbone of PanguAlpha network
Input embedding, i.e., word embedding and position embedding
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
input_mask: the mask indicating whether each position is a valid input
layer_past: the previous feature map
Returns:
output_state: Tensor, the output logit of backbone
present_layer: Tensor, the current feature map
embedding_table: Tensor, the embedding table for the vocabulary
input_position: the position index of each token
attention_mask: the attention mask matrix for the self-attention module
valid_index: only used in incremental inference, the position index of current token
Returns:
hidden_states: Tensor, input embeddings
attention_mask: Tensor, attention_mask matrix
embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
"""
def __init__(self, config):
super(PanguAlpha_Model, self).__init__()
super(Embedding, self).__init__()
self.get_attention_mask = AttentionMask(config)
# Word embedding
self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
@ -780,6 +978,59 @@ class PanguAlpha_Model(nn.Cell):
self.position_embedding.embedding_table.parallel_optimizer = False
self.position_embedding.gather.shard(((1, 1), (config.dp,)))
self.position_embedding.expand.shard(((config.dp, 1),))
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.eod_reset = config.eod_reset
self.use_past = config.use_past
self.is_first_iteration = True
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, valid_index=None):
"""
Calculate input embeddings via input token ids and input position
"""
# Word embedding
input_embedding, embedding_table = self.word_embedding(input_ids)
# If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids
# and the corresponding input_position and attention_mask will be derived from it.
if not self.eod_reset:
batch_size, seq_length = F.shape(input_ids)
attention_mask = self.get_attention_mask(input_mask)
if self.use_past and not self.is_first_iteration:
input_position = valid_index.view(1, seq_length)
else:
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (batch_size, 1))
position_embedding = self.position_embedding(input_position)
# Input features [bs, seq_length, embedding_size]
hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states)
hidden_states = P.Cast()(hidden_states, mstype.float16)
attention_mask = self.expand_dims(attention_mask, 1)
return hidden_states, attention_mask, embedding_table
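Editor note: in the incremental graph the input has length one, so the Embedding block looks up the position embedding at the single valid index instead of tiling the full 0..seq_length-1 range. A tiny numpy illustration of the two cases, with a made-up table and index:

import numpy as np

seq_length, embedding_size = 8, 4
position_table = np.random.randn(seq_length, embedding_size)

# First graph: a (bs, seq_length) input uses the full position range
input_position = np.tile(np.arange(seq_length), (1, 1))           # (bs, seq_length)
full_pos_emb = position_table[input_position]                      # (bs, seq_length, embedding_size)

# Incremental graph: a (bs, 1) input only needs the current position
valid_index = np.array([5])                                        # position being decoded
step_pos_emb = position_table[valid_index.reshape(1, -1)]          # (bs, 1, embedding_size)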
class PanguAlpha_Model(nn.Cell):
"""
The backbone of PanguAlpha network
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
input_mask: the mask indicating whether each position is a valid input
input_position: the position index of each token
attention_mask: the attention mask matrix for the self-attention module
init_reset: whether reset saved key and value states
batch_valid_length: the valid input sequence length without padding
Returns:
output_state: Tensor, the output logit of backbone
embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
super(PanguAlpha_Model, self).__init__()
self.embedding = Embedding(config)
self.blocks = nn.CellList()
# Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
fusion_group_num = 4
@ -794,9 +1045,9 @@ class PanguAlpha_Model(nn.Cell):
print("After setting the layer is:", num_layers, flush=True)
for i in range(num_layers):
per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
# Each layer will be recomputed in the backward process. The output activation of each layer will be saved,
# in other words, in backward process each block will be almost totally recomputed.
per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
# Each layer will be recomputed in the backward process. The output activation of each layer will be saved,
# in other words, in the backward process each block will be almost totally recomputed.
if config.use_recompute:
per_block.recompute()
self.blocks.append(per_block)
@ -813,13 +1064,7 @@ class PanguAlpha_Model(nn.Cell):
self.layernorm.beta.parallel_optimizer = False
self.use_past = config.use_past
self.past = tuple([None] * config.num_layers)
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
self.dtype = config.compute_dtype
self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.eod_reset = config.eod_reset
# If top_query_attention enabled, the input_position representing the position ids will be used as the index
# for a query embedding table to obtain top query hidden states, together with the previous outputs of normal
# self-attention layers, a new attention layer will be attached to the output of the model
@ -853,33 +1098,18 @@ class PanguAlpha_Model(nn.Cell):
self.use_top_query_attention = config.use_top_query_attention
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
init_reset=True, batch_valid_length=None):
"""PanguAlpha model"""
if not self.use_past:
layer_past = self.past
# Embedding for input_ids and the lower-triangular attention_mask matrix
hidden_states, attention_mask, embedding_table = self.embedding(input_ids, input_mask,
input_position, attention_mask,
batch_valid_length)
# Word embedding
input_embedding, embedding_table = self.word_embedding(input_ids)
# If eod_reset disabled, there will be only one input from the dataset, i.e., input_ids
# and the corresponding input_position and attention_mask will be derived from it.
if not self.eod_reset:
batch_size, seq_length = F.shape(input_ids)
input_position = F.tuple_to_array(F.make_range(seq_length))
input_position = P.Tile()(input_position, (batch_size, 1))
attention_mask = self.get_attention_mask(input_mask)
position_embedding = self.position_embedding(input_position)
# Input features [bs, seq_length, embedding_size]
hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states)
hidden_states = P.Cast()(hidden_states, mstype.float16)
attention_mask = self.expand_dims(attention_mask, 1)
present_layer = ()
# Loop through each self-attention layer
for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
hidden_states, _ = self.blocks[i](hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
@ -887,10 +1117,9 @@ class PanguAlpha_Model(nn.Cell):
# Top query attention layer
if self.use_top_query_attention:
top_query_hidden_states = self.top_query_embedding(input_position)
output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
return output_state, present_layer, embedding_table
output_state, _ = self.top_query_layer(output_state, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
return output_state, embedding_table
class PanguAlpha_ModelPipeline(nn.Cell):
"""
@ -922,7 +1151,7 @@ class PanguAlpha_ModelPipeline(nn.Cell):
self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
per_block = QueryLayer(config).set_comm_fusion(2)
else:
per_block = Block(config, i + 1).set_comm_fusion(2)
per_block = Decoder(config, i + 1).set_comm_fusion(2)
per_block.pipeline_stage = i * config.stage_num // config.num_layers
per_block.recompute()
self.blocks.append(per_block)
@ -936,27 +1165,23 @@ class PanguAlpha_ModelPipeline(nn.Cell):
self.layernorm.set_comm_fusion(2)
self.layernorm.pipeline_stage = config.stage_num - 1
self.use_past = config.use_past
self.past = tuple([None] * config.num_layers)
self.dtype = config.compute_dtype
self.num_layers = config.num_layers
def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
def construct(self, input_ids, input_mask, table, input_position, attention_mask,
init_reset=True, batch_valid_length=None):
"""PanguAlpha model"""
if not self.use_past:
layer_past = self.past
hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)
present_layer = ()
for i in range(self.num_layers-1):
hidden_states, present = self.blocks[i](hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
hidden_states, _ = self.blocks[i](hidden_states,
attention_mask, init_reset, batch_valid_length)
top_query_hidden_states = self.top_query_embedding(input_position)
hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
hidden_states, present_layer = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, init_reset, batch_valid_length)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
return output_state, present_layer
@ -998,7 +1223,10 @@ class PanguAlpha(nn.Cell):
Inputs:
input_ids: the tokenized inputs
input_mask: the mask indicating whether each position is a valid input
past: the previous feature map
input_position: the position index of each token
attention_mask: the attention mask matrix for the self-attention module
init_reset: whether reset saved key and value states
batch_valid_length: the valid input sequence length without padding
Returns:
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
@ -1009,9 +1237,10 @@ class PanguAlpha(nn.Cell):
# Network head to get logits over vocabulary
self.head = PanguAlpha_Head(config)
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
output_states, _, embedding_table = self.backbone(
input_ids, input_mask, input_position, attention_mask, past)
def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
init_reset=True, batch_valid_length=None):
output_states, embedding_table = self.backbone(
input_ids, input_mask, input_position, attention_mask, init_reset, batch_valid_length)
logits = self.head(output_states, embedding_table)
return logits
@ -1039,9 +1268,10 @@ class PanguAlphaPipeline(nn.Cell):
self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
def construct(self, input_ids, input_mask, input_position, attention_mask,
init_reset=True, batch_valid_length=None):
output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
input_position, attention_mask, past)
input_position, attention_mask, init_reset, batch_valid_length)
logits = self.head(output_states, self.embedding_table)
return logits
@ -1213,24 +1443,27 @@ class EvalNet(nn.Cell):
Inputs:
input_ids: the tokenized inputs
current_index: the index of current token
init_reset: whether reset saved states
Returns:
outputs: Tensor, corresponding output for different tasks
"""
def __init__(self, backbone, generate=False):
def __init__(self, backbone, generate=False, pad_token=6):
super(EvalNet, self).__init__(auto_prefix=False)
self.backbone = backbone
self.pad_token = pad_token
self.argmax = P.Argmax()
self.generate = generate
self.topk = P.TopK(sorted=True).shard(((1, 1),))
self.gather = P.GatherV2().shard(((1, 1, 1), (1,)))
self.log_softmax = P.LogSoftmax().shard(((1, 1),))
self.gather = P.GatherV2().shard(((1, 1), (1,)))
self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))
def construct(self, input_ids, current_index):
def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
"""evaluation net"""
input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
logits = self.backbone(input_ids, input_mask)
bs, seq_length = F.shape(input_ids)
logits = F.reshape(logits, (bs, seq_length, -1))
logits = self.gather(logits, current_index, 1)
index = current_index.view(1,)
logits = self.gather(logits, index, 0)
bs, _ = F.shape(input_ids)
logits = logits.view(bs, 1, -1)
log_probs = self.log_softmax(logits)
return log_probs
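Editor note: the reworked EvalNet head gathers the backbone logits at current_index and returns log-probabilities of shape (bs, 1, vocab_size). A hedged numpy sketch of that post-processing, assuming the backbone returns logits flattened over batch and sequence as the gather/view calls above suggest:

import numpy as np

bs, seq_length, vocab_size = 1, 8, 16
flat_logits = np.random.randn(bs * seq_length, vocab_size)   # backbone output flattened over (bs, seq_length)

current_index = np.array([3])                                # index of the last valid token
picked = flat_logits[current_index]                          # gather along axis 0 -> (1, vocab_size)
picked = picked.reshape(bs, 1, -1)                           # (bs, 1, vocab_size)
log_probs = picked - np.log(np.exp(picked).sum(axis=-1, keepdims=True))   # log_softmax over the vocabulary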