!19163 pangu alpha incremental inference
Merge pull request !19163 from wangjun/master_0630
Commit d642e1de53
@@ -111,7 +111,16 @@ def run_predict(args_opt):
     # Compile network and obtain tensor layout for loading ckpt
     inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
-    predict_layout = model_predict.infer_predict_layout(inputs_np)
+    current_index = Tensor(np.array([0]), mstype.int32)
+
+    if config.use_past:
+        batch_valid_length = Tensor(np.array([0]), mstype.int32)
+        init_true = Tensor([True], mstype.bool_)
+        inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
+        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
+        _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
+    else:
+        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
     print("======start load_distributed checkpoint", flush=True)
     # For 2.6B and 13B models, the number of ckpt files is 512.
     ckpt_name = 'filerted'
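For context, the sketch below is a minimal, framework-free mock of the calling pattern the two infer_predict_layout calls prepare for: one graph is compiled for the full (batch_size, seq_length) prompt, and a second for the (batch_size, 1) inputs used at every later decoding step. The fake_predict stub, the toy shapes and the token values are hypothetical stand-ins, not part of this PR.

import numpy as np

def fake_predict(input_ids, current_index, init, batch_valid_length):
    # Hypothetical stand-in for model_predict.predict(): returns random logits
    # so the two-graph control flow can be exercised without MindSpore.
    batch_size = input_ids.shape[0]
    vocab_size = 40000
    return np.random.randn(batch_size, vocab_size)

batch_size, seq_length = 1, 8                        # assumed toy sizes
prompt = np.array([[31, 7, 99, 0, 0, 0, 0, 0]])      # padded prompt with 3 valid tokens
valid_length = 3

# First graph: full (batch_size, seq_length) input, init=False resets the caches.
logits = fake_predict(prompt, current_index=valid_length - 1,
                      init=False, batch_valid_length=valid_length - 1)

# Second graph: one token per step, init=True reuses the saved key/value states.
for _ in range(4):
    next_token = int(np.argmax(logits[0]))
    step_input = np.array([[next_token]])            # shape (batch_size, 1)
    logits = fake_predict(step_input, current_index=0,
                          init=True, batch_valid_length=valid_length)
    valid_length += 1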
@@ -123,7 +132,7 @@ def run_predict(args_opt):
     print("================load param ok=================", flush=True)

     from src.tokenization_jieba import JIEBATokenizer
-    from src.generate import generate
+    from src.generate import generate, generate_increment
     # Define tokenizer
     tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
                                os.path.join(args_opt.tokenizer_path, 'vocab10.model'))
@@ -134,7 +143,8 @@ def run_predict(args_opt):
     start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
     input_ids = np.array(start_sentence).reshape(1, -1)
     # Call inference
-    output_ids = generate(model_predict, input_ids, opt)
+    generate_func = generate_increment if config.use_past else generate
+    output_ids = generate_func(model_predict, input_ids, opt)
     # Decode output ids to sentence
     output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
     print('Output is:', output_samples, flush=True)
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P

 def generate(model, origin_inputs, config):
     """
-    TopK for text generation
+    Text generation

     Inputs:
         model: the model for inferencing
@@ -59,7 +59,7 @@ def generate(model, origin_inputs, config):
     while valid_length < target_length:
         inputs = Tensor(input_ids, mstype.int32)
         # Indicate the exact token position
-        current_index = valid_length - 1 if valid_lenght - 1 > 0 else 0
+        current_index = valid_length - 1 if valid_length - 1 > 0 else 0
         current_index = Tensor([current_index], mstype.int32)
         # Call a single inference
         log_probs = model.predict(inputs, current_index)
@@ -103,6 +103,10 @@ def generate(model, origin_inputs, config):
         if p_args[target_index] == end_token or valid_length == target_length-1:
             outputs = input_ids
             break
+
+        # update frequency list
+        target = p_args[target_index]
+        frequency_list[0][target] = frequency_list[0][target] + 1
         # Modify input_ids with newly generated token
         input_ids[0][valid_length] = p_args[target_index]
         valid_length += 1
@@ -110,3 +114,117 @@ def generate(model, origin_inputs, config):
     length = np.sum(outputs != 0)
     outputs = outputs[0][:length]
     return outputs
+
+def generate_increment(model, origin_inputs, config):
+    """
+    Text generation for incremental inference
+
+    Inputs:
+        model: the model for inferencing
+        origin_inputs: the original inputs based on which the model will continue writing
+        config: inference configurations
+
+    Returns:
+        outputs: the ids for the generated text
+    """
+    # Get configurations for inference
+    frequency_penalty = config.frequency_penalty
+    presence_penalty = config.presence_penalty
+    top_p = config.top_p
+    top_k_num = config.top_k_num
+    max_generate_length = config.max_generate_length
+    seq_length = config.seq_length
+    end_token = config.end_token
+
+    _, valid_length = origin_inputs.shape
+    # Init outputs with original inputs
+    outputs = [origin_inputs[0][i] for i in range(valid_length)]
+    # If target length exceeds seq_length, use seq_length instead
+    target_length = valid_length + max_generate_length
+    target_length = seq_length if target_length > seq_length else target_length
+
+    # A list of the frequency of each token
+    frequency_list = np.array([[0 for _ in range(config.vocab_size)]])
+    pad_length = seq_length - origin_inputs.shape[-1]
+    # Pad original inputs to seq_length
+    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
+    print("input_ids is ", input_ids)
+
+    # Indicate the exact token position
+    current_index = valid_length - 1 if valid_length - 1 > 0 else 0
+    batch_valid_length = Tensor(np.array([current_index]), mstype.int32)
+    current_index = Tensor(np.array([current_index]), mstype.int32)
+    # For the first graph, not_init should be False
+    init_true = Tensor([True], mstype.bool_)
+    init_false = Tensor([False], mstype.bool_)
+    init = init_false
+    # Claim the first graph
+    model.predict_network.add_flags_recursive(is_first_iteration=True)
+    # Call a single inference with input size of (bs, seq_length)
+    logits = model.predict(Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length)
+
+    # Claim the second graph and set not_init to true
+    init = init_true
+    model.predict_network.add_flags_recursive(is_first_iteration=False)
+
+    # A single loop generates one token, loop until reaching target seq_length or generating eod token
+    while valid_length < target_length:
+        # Reshape the output logits
+        logits = logits.asnumpy()
+        log_probs = logits.reshape(1, config.vocab_size)
+
+        # Get the revised log_probs considering frequency and presence penalty to eliminate duplicates in generated results
+        log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
+
+        # Convert the log_probs to probability
+        logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
+
+        # If top_p is less than 1.0, use top_p sampling
+        if top_p < 1.0:
+            # Only consider the 5000 largest logits to reduce computation
+            sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
+            cumsum_logits = P.CumSum()(sorted_logits, 1)
+            cumsum_logits = cumsum_logits.asnumpy()[0]
+            index = index.asnumpy()[0]
+            sorted_logits = sorted_logits.asnumpy()[0]
+            top_p_num = sum(cumsum_logits > top_p)
+            # In case the probability is smooth, the sum of the 5000 largest probabilities may not be large enough
+            if top_p_num == 0:
+                top_p_num = 5000
+            # Get the corresponding probs and indices
+            probs = sorted_logits[:top_p_num]
+            p_args = index[:top_p_num]
+            p = probs / sum(probs)
+        # If top_p is set to 1.0, use top_k sampling
+        else:
+            # Get the corresponding probs and indices
+            probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
+            probs = probs.asnumpy()[0]
+            p_args = p_args.asnumpy()[0]
+            # Avoid rounding error
+            if sum(probs) == 0:
+                probs = np.array([1 / top_k_num for _ in range(top_k_num)])
+            p = probs / sum(probs)
+
+        # Randomly select a token as the final output of this round
+        target_index = np.random.choice(len(p), p=p)
+        # Stop judgment
+        if p_args[target_index] == end_token or valid_length == target_length-1:
+            break
+
+        # Update frequency list
+        target = p_args[target_index]
+        frequency_list[0][target] = frequency_list[0][target] + 1
+        valid_length += 1
+
+        batch_valid_length = Tensor(np.array([valid_length - 1]), mstype.int32)
+        current_index = Tensor(np.array([0]), mstype.int32)
+        input_id = Tensor([[target]], mstype.int32)
+        # Update outputs with the current generated token
+        outputs.append(int(target))
+
+        # Call a single inference with input size of (bs, 1)
+        logits = model.predict(input_id, current_index, init, batch_valid_length)
+    # Return valid outputs out of padded outputs
+    return np.array(outputs)
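As a self-contained reference for the sampling loop in generate_increment, here is a NumPy-only sketch of a single decoding step under the same scheme: frequency and presence penalties are subtracted from the log-probabilities, and either nucleus (top-p) or top-k sampling picks the next token. The name sample_next_token, the toy sizes and the default penalty values are assumptions for illustration; the MindSpore operators used in the PR (P.Pow, P.TopK, P.CumSum) are replaced by NumPy equivalents.

import numpy as np

def sample_next_token(log_probs, frequency_list, top_p=0.9, top_k_num=5,
                      frequency_penalty=1.5, presence_penalty=0.3):
    """One sampling step over a (1, vocab_size) row of log-probabilities."""
    # Penalize tokens that were already generated (frequency/presence penalties).
    revised = log_probs - frequency_list * frequency_penalty \
                        - (frequency_list > 0) * presence_penalty
    # Convert back to unnormalized probabilities, as the PR does with P.Pow(10, .).
    probs = np.power(10.0, revised)[0]

    if top_p < 1.0:
        # Nucleus sampling: keep the smallest prefix of sorted tokens whose
        # cumulative mass reaches top_p (at least one token is kept).
        order = np.argsort(-probs)
        sorted_probs = probs[order]
        cumsum = np.cumsum(sorted_probs / sorted_probs.sum())
        cutoff = int(np.searchsorted(cumsum, top_p) + 1)
        cand_probs, cand_ids = sorted_probs[:cutoff], order[:cutoff]
    else:
        # Top-k sampling: keep the k most likely tokens.
        cand_ids = np.argpartition(-probs, top_k_num - 1)[:top_k_num]
        cand_probs = probs[cand_ids]

    p = cand_probs / cand_probs.sum()
    return int(np.random.choice(cand_ids, p=p))

vocab_size = 100
log_probs = np.random.randn(1, vocab_size)
frequency_list = np.zeros((1, vocab_size))
token = sample_next_token(log_probs, frequency_list)
frequency_list[0][token] += 1   # bookkeeping mirrors the PR's frequency_list update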
@@ -420,7 +420,28 @@ class Attention(nn.Cell):
         self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1)))
         self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,)))

-    def construct(self, x, attention_mask, layer_past=None):
+        self.is_first_iteration = True
+        self.dtype = config.compute_dtype
+        self.use_past = config.use_past
+        if self.use_past:
+            # operators used for state reuse
+            seq_range = np.arange(config.seq_length).reshape(1, 1, -1)
+            self.range = Tensor(np.tile(seq_range, (config.batch_size, 1, 1)), mstype.int32)
+            self.seq_length = config.seq_length
+            self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32)
+            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
+            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
+            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
+            self.expand_dims = P.ExpandDims().shard(((1, 1, 1),))
+            self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1)))
+            self.add = P.TensorAdd().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
+            self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1)))
+            self.sub1 = P.Sub().shard(((1,), ()))
+            self.tile = P.Tile().shard(((1, 1, 1, 1),))
+            self.less = P.Less().shard(((1, 1, 1), (1, 1, 1)))
+            self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
+
+    def construct(self, x, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
         """
         self-attention

@@ -428,7 +449,9 @@ class Attention(nn.Cell):
             x: output of previous layer
             attention_mask: the attention mask matrix with shape (batch_size, 1,
             seq_length, seq_length)
-            layer_past: the previous feature map
+            key_past: previous saved key state
+            value_past: previous saved value state
+            batch_valid_length: the valid input seq_length without padding

         Returns:
             output: Tensor, the output logit of this layer
@@ -458,12 +481,44 @@ class Attention(nn.Cell):
                                                value,
                                                (-1, original_shape[1], self.n_head, self.size_per_head)),
                                      (0, 2, 1, 3))

+        # key and value for current token(s)
+        key_present = key
+        value_present = value
         if self.use_past:
-            past_value = layer_past[1]
-            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
-            key = self.concat_k((past_key, key))
-            value = self.concat_v((past_value, value))
-        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
+            # The first graph with the input size of (bs, seq_length)
+            if self.is_first_iteration:
+                # Get the valid input length without padding
+                valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
+                # Cover the key and value numbers corresponding to the padding position
+                key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
+                value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
+            # The second graph with the input size of (bs, 1)
+            # the shape of query is (bs, num_heads, 1, size_per_head)
+            # the shape of key is (bs, num_heads, size_per_head, 1)
+            # the shape of value is (bs, num_heads, 1, size_per_head)
+            else:
+                # Get the current token position index
+                valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
+                                                                               (F.shape(x)[0], 1, 1, self.seq_length),
+                                                                               (1, 1, 1, 1)),
+                                                                    0), mstype.float32), (1, 2, 3))
+                valid_length = F.reshape(valid_length, (-1, 1, 1))
+                valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
+                # Pad the key and value to seq_length; only the current position index is non-zero
+                current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
+                                        self.expand_dims(valid_length_vector, 2))
+                current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
+                                          self.expand_dims(valid_length_vector, 3))
+                # Concat the previous saved state and current state
+                key = self.add(key_past, current_key)
+                value = self.add(value_past, current_value)
+                # Update key_present and value_present for state update
+                key_present = key
+                value_present = value
+                attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
+
+        layer_present = (key_present, value_present)
+        # Self-attention considering attention mask
         attention = self._attn(query, key, value, attention_mask)
         # [bs, seq_length, embedding_size]
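To make the state-reuse logic in Attention.construct concrete, the following NumPy sketch (toy shapes, no sharding, hypothetical helper names) shows the two cases: in the first graph the freshly computed key is masked so that padded positions contribute zeros to the cache, and in the incremental graph the single new key column is tiled to seq_length, zeroed everywhere except the current position, and added onto the saved cache.

import numpy as np

bs, num_heads, size_per_head, seq_length = 1, 2, 4, 8   # assumed toy sizes

def first_graph_cache(key, batch_valid_length):
    # key: (bs, num_heads, size_per_head, seq_length); zero out padded positions.
    positions = np.arange(seq_length).reshape(1, 1, seq_length)
    valid = (positions < batch_valid_length.reshape(-1, 1, 1)).astype(key.dtype)
    return key * valid[:, :, None, :]                   # broadcast over size_per_head

def increment_cache(key_past, current_key, valid_length):
    # current_key: (bs, num_heads, size_per_head, 1) for the newest token only.
    positions = np.arange(seq_length).reshape(1, 1, seq_length)
    one_hot = (positions == valid_length.reshape(-1, 1, 1)).astype(key_past.dtype)
    # Tile the single column to seq_length, keep only the current slot, then add.
    scattered = np.tile(current_key, (1, 1, 1, seq_length)) * one_hot[:, :, None, :]
    return key_past + scattered

key = np.random.randn(bs, num_heads, size_per_head, seq_length)
key_past = first_graph_cache(key, np.array([3]))              # 3 valid prompt tokens
new_key = np.random.randn(bs, num_heads, size_per_head, 1)
key_past = increment_cache(key_past, new_key, np.array([3]))  # new token at position 3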
@@ -523,7 +578,7 @@ class Attention(nn.Cell):
             query = query / F.cast(self.coeff, F.dtype(query))
             key = key / F.cast(self.coeff, F.dtype(key))

-        # Attention score [bs, num_heads, seq_length, seq_length]
+        # Attention score [bs, num_heads, seq_length_q, seq_length_k]
         score = self.batch_matmul(query, key)
         # Normalize after query and key MatMul, default on
         if self.scale:
@@ -533,6 +588,22 @@ class Attention(nn.Cell):

         ori_dtype = P.DType()(score)
         score = P.Cast()(score, mstype.float32)
+
+        # For an input size of (bs, 1), namely the second graph, the shape of the attention_mask matrix
+        # should be (bs, 1, 1, seq_length)
+        if self.use_past and not self.is_first_iteration:
+            # Calculate the current total number of tokens
+            current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
+                                                                            (F.shape(query)[0], 1, 1, self.seq_length),
+                                                                            (1, 1, 1, 1)),
+                                                                 0), mstype.float32), (1, 2, 3))
+            # Get the precise position index
+            index = self.sub1(F.cast(current_index, mstype.int32), 1)
+            index = F.reshape(index, (-1, 1, 1))
+            # Calculate the attention_mask matrix via the position index
+            attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
+            attention_mask = self.expand_dims(attention_mask, 2)
+
         # Minus 10000 for the position where masked to exclude them from softmax
         multiplu_out = self.sub(
             P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
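A small NumPy illustration of the incremental attention mask built above (assumed toy sizes, no MindSpore ops): the number of cached tokens is recovered by counting non-zero columns of the key cache, and a (bs, 1, 1, seq_length) mask is formed that allows attention only up to the current position.

import numpy as np

bs, num_heads, size_per_head, seq_length = 1, 2, 4, 8   # assumed toy sizes
key_cache = np.zeros((bs, num_heads, size_per_head, seq_length))
key_cache[..., :5] = np.random.randn(bs, num_heads, size_per_head, 5)  # 5 cached tokens

# Count non-zero columns of one head's slice, as the PR does with NotEqual + ReduceSum.
current_total = (key_cache[:, 0:1, 0:1, :] != 0).sum(axis=(1, 2, 3))    # -> [5]
index = current_total.astype(np.int32) - 1                              # position of newest token

# Positions <= index may be attended to; shape becomes (bs, 1, 1, seq_length).
positions = np.arange(seq_length).reshape(1, 1, seq_length)
attention_mask = (positions <= index.reshape(-1, 1, 1)).astype(np.int32)
attention_mask = attention_mask[:, :, None, :]
print(attention_mask)   # [[[[1 1 1 1 1 0 0 0]]]]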
@@ -550,27 +621,28 @@ class Attention(nn.Cell):
         attention_probs = F.reshape(attention_probs, shape)

         attention_probs = self.prob_dropout(attention_probs)
-        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
+        # Weighted sum output [bs, num_heads, seq_length_q, size_per_head]
         weighted_values = self.batch_matmul(attention_probs, value)
         return weighted_values


-class Block(nn.Cell):
+class Decoder(nn.Cell):
     """
-    The basic block of PanguAlpha network
+    The basic decoder structure of PanguAlpha network
     Args:
         config(PanguAlphaConfig): the config of network
         layer_idx: current layer index
     Inputs:
         x: the output of previous layer(input_ids for the first layer)
         attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
-        layer_past: the previous feature map
+        init_reset: whether reset the previous state
+        batch_valid_length: the valid input seq_length without padding
     Returns:
         output: Tensor, the output logit of this layer
         layer_present: Tensor, the feature map of current layer
     """
     def __init__(self, config, layer_idx):
-        super(Block, self).__init__()
+        super(Decoder, self).__init__()
         scale = 1 / math.sqrt(2.0 * config.num_layers)

         if config.self_layernorm:
@@ -593,16 +665,43 @@ class Block(nn.Cell):
         self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
         # Last activation of this layer will be saved for recompute in backward process
         self.dtype = config.compute_dtype
+        self.use_past = config.use_past
+        if self.use_past:
+            # operators used for state reuse
+            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
+            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
+            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
+            size_per_head = int(config.embedding_size / config.num_heads)
+            self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length)
+            self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head)
+            # parameters saving key and value states
+            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
+            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
+            self.tile = P.Tile().shard(((1, 1),))
+            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
+            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

-    def construct(self, x, input_mask, layer_past=None):
+    def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
         r"""
         The forward process of the block.
         """
         # [bs, seq_length, embedding_size]
         input_x = self.layernorm1(x)
         input_x = F.cast(input_x, self.dtype)
+
+        # indicate whether to reset the saved states
+        key_reset = None
+        value_reset = None
+
+        if self.use_past:
+            # reset states, init_reset True for reuse and False for reset
+            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
+            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
+            # add dependency for desired execution order
+            input_x = F.depend(input_x, key_reset)
+            input_x = F.depend(input_x, value_reset)
+
         attention, layer_present = self.attention(input_x, input_mask,
-                                                  layer_past)
+                                                  self.key_past, self.value_past, batch_valid_length)
         # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm
         if self.post_layernorm_residual:
             x = self.add(input_x, attention)
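The init_reset mechanism added to the decoder block amounts to multiplying the cached states by the boolean before each step; a short NumPy sketch of that semantics follows (hypothetical values; the Assign and depend ops are omitted).

import numpy as np

key_past = np.ones((1, 2, 4, 8))      # pretend this is the saved key cache

def apply_init_reset(cache, init_reset):
    # init_reset=False -> multiply by 0.0 and clear the cache (start of a new sequence)
    # init_reset=True  -> multiply by 1.0 and keep reusing the saved states
    return cache * np.float32(init_reset)

key_past = apply_init_reset(key_past, False)         # first graph: cache is zeroed
assert not key_past.any()
key_past = apply_init_reset(key_past + 1.0, True)    # later steps: cache kept as-is
assert key_past.all()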
@@ -613,6 +712,22 @@ class Block(nn.Cell):
         output_x = self.layernorm2(x)
         output_x = F.cast(output_x, self.dtype)
         mlp_logit = self.output(output_x)
+
+        value_update = None
+        key_update = None
+        if self.use_past:
+            # current key and value
+            key_present, value_present = layer_present
+            # update key and value calculated this step
+            key_update = self.assign(self.key_past, key_present)
+            value_update = self.assign(self.value_past, value_present)
+            # add dependency for desired execution order
+            key_update = F.depend(key_update, key_reset)
+            value_update = F.depend(value_update, value_reset)
+
+        # add dependency for desired execution order
+        mlp_logit = F.depend(mlp_logit, value_update)
+        mlp_logit = F.depend(mlp_logit, key_update)
         if self.post_layernorm_residual:
             output = self.add(output_x, mlp_logit)
         else:
@@ -635,9 +750,14 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell):
         self.dropout = Dropout(1 - config.dropout_rate)
         self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
         self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
+        self.use_past = config.use_past
+        self.is_first_iteration = True

-    def construct(self, input_ids, table, input_position):
+    def construct(self, input_ids, table, input_position, valid_index=None):
         input_embedding = self.word_embedding(input_ids, table)
+        if self.use_past and not self.is_first_iteration:
+            _, seq_length = F.shape(input_ids)
+            input_position = valid_index.view(1, seq_length)
         position_embedding = self.position_embedding(input_position)
         hidden_states = self.add(input_embedding, position_embedding)
         hidden_states = self.dropout(hidden_states)
@@ -662,7 +782,7 @@ class QueryLayerAttention(Attention):
     r"""
     Self-Attention module using input query vector.
     """
-    def construct(self, x, query_hidden_state, attention_mask, layer_past=None):
+    def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None):
         original_shape = F.shape(x)
         x = F.reshape(x, (-1, original_shape[-1]))
         query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
@@ -684,12 +804,42 @@ class QueryLayerAttention(Attention):
                                                value,
                                                (-1, original_shape[1], self.n_head, self.size_per_head)),
                                      (0, 2, 1, 3))

+        key_present = key
+        value_present = value
         if self.use_past:
-            past_value = layer_past[1]
-            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
-            key = self.concat_k((past_key, key))
-            value = self.concat_v((past_value, value))
-        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
+            # The first graph with the input size of (bs, seq_length)
+            if self.is_first_iteration:
+                # Get the valid input length without padding
+                valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype)
+                # Cover the key and value numbers corresponding to the padding position
+                key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
+                value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
+            # The second graph with the input size of (bs, 1)
+            # the shape of query is (bs, num_heads, 1, size_per_head)
+            # the shape of key is (bs, num_heads, size_per_head, 1)
+            # the shape of value is (bs, num_heads, 1, size_per_head)
+            else:
+                # Get the current token position index
+                valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0),
+                                                                               (F.shape(x)[0], 1, 1, self.seq_length),
+                                                                               (1, 1, 1, 1)),
+                                                                    0), mstype.float32), (1, 2, 3))
+                valid_length = F.reshape(valid_length, (-1, 1, 1))
+                valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype)
+                # Pad the key and value to seq_length; only the current position index is non-zero
+                current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)),
+                                        self.expand_dims(valid_length_vector, 2))
+                current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)),
+                                          self.expand_dims(valid_length_vector, 3))
+                # Concat the previous saved state and current state
+                key = self.add(key_past, current_key)
+                value = self.add(value_past, current_value)
+                # Update key_present and value_present for state update
+                key_present = key
+                value_present = value
+                attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1))
+        layer_present = (key_present, value_present)
         attention = self._attn(query, key, value, attention_mask)
         attention_merge = self.merge_heads(attention)
         output = self.projection(attention_merge)
@@ -715,18 +865,48 @@ class QueryLayer(nn.Cell):
         self.post_layernorm_residual = config.post_layernorm_residual
         self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
         self.dtype = config.compute_dtype
+        self.use_past = config.use_past
+        if self.use_past:
+            # operators used for state reuse
+            self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),))
+            self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
+            self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
+            size_per_head = int(config.embedding_size / config.num_heads)
+            self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length)
+            self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head)
+            # parameters saving key and value states
+            self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
+            self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
+            self.tile = P.Tile().shard(((1, 1),))
+            self.mul = P.Mul().shard(((1, 1, 1, 1), (1,)))
+            self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

-    def construct(self, x, query_hidden_state, input_mask, layer_past=None):
+    def construct(self, x, query_hidden_state, input_mask, init_reset=True, batch_valid_length=None):
         r"""
         Query Layer shares a similar structure with a normal layer block
         except that it is not a traditional self-attention.
         """
         input_x = self.layernorm1(x)
         input_x = F.cast(input_x, self.dtype)
+
+        # indicate whether to reset the saved states
+        key_reset = None
+        value_reset = None
+
+        if self.use_past:
+            # reset states, init_reset True for reuse and False for reset
+            key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype)))
+            value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype)))
+            # add dependency for desired execution order
+            input_x = F.depend(input_x, key_reset)
+            input_x = F.depend(input_x, value_reset)
+
         attention, layer_present = self.attention(input_x,
                                                   query_hidden_state,
                                                   input_mask,
-                                                  layer_past)
+                                                  self.key_past,
+                                                  self.value_past,
+                                                  batch_valid_length)
         if self.post_layernorm_residual:
             x = self.add(input_x, attention)
         else:
@@ -735,28 +915,46 @@ class QueryLayer(nn.Cell):
         output_x = self.layernorm2(x)
         output_x = F.cast(output_x, self.dtype)
         mlp_logit = self.output(output_x)
+        value_update = None
+        key_update = None
+        if self.use_past:
+            # current key and value
+            key_present, value_present = layer_present
+            # update key and value calculated this step
+            key_update = self.assign(self.key_past, key_present)
+            value_update = self.assign(self.value_past, value_present)
+            # add dependency for desired execution order
+            key_update = F.depend(key_update, key_reset)
+            value_update = F.depend(value_update, value_reset)
+
+        # add dependency for desired execution order
+        mlp_logit = F.depend(mlp_logit, value_update)
+        mlp_logit = F.depend(mlp_logit, key_update)
+
         if self.post_layernorm_residual:
             output = self.add(output_x, mlp_logit)
         else:
             output = self.add(x, mlp_logit)
         return output, layer_present

-class PanguAlpha_Model(nn.Cell):
+class Embedding(nn.Cell):
     """
-    The backbone of PanguAlpha network
+    Input embedding, i.e., word embedding and position embedding
     Args:
         config(PanguAlphaConfig): the config of network
     Inputs:
         input_ids: the tokenized inputs with datatype int32
         input_mask: the mask indicating whether each position is a valid input
-        layer_past: the previous feature map
+        input_position: the position index of each token
+        attention_mask: the attention_mask attention for self-attention module
+        valid_index: only used in incremental inference, the position index of the current token
     Returns:
-        output_state: Tensor, the output logit of backbone
-        present_layer: Tensor, the current feature map
-        embedding_table: Tensor, the embedding table for the vocabulary
+        hidden_states: Tensor, input embeddings
+        attention_mask: Tensor, attention_mask matrix
+        embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
     """
     def __init__(self, config):
-        super(PanguAlpha_Model, self).__init__()
+        super(Embedding, self).__init__()
         self.get_attention_mask = AttentionMask(config)
         # Word embedding
         self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
@@ -780,6 +978,59 @@ class PanguAlpha_Model(nn.Cell):
         self.position_embedding.embedding_table.parallel_optimizer = False
         self.position_embedding.gather.shard(((1, 1), (config.dp,)))
         self.position_embedding.expand.shard(((config.dp, 1),))
+        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
+        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
+        self.dropout = Dropout(1 - config.dropout_rate)
+        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
+        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
+        self.eod_reset = config.eod_reset
+        self.use_past = config.use_past
+        self.is_first_iteration = True
+
+    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, valid_index=None):
+        """
+        Calculate input embeddings via input token ids and input position
+        """
+        # Word embedding
+        input_embedding, embedding_table = self.word_embedding(input_ids)
+        # If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids,
+        # and the corresponding input_position and attention_mask will be derived from it.
+        if not self.eod_reset:
+            batch_size, seq_length = F.shape(input_ids)
+            attention_mask = self.get_attention_mask(input_mask)
+            if self.use_past and not self.is_first_iteration:
+                input_position = valid_index.view(1, seq_length)
+            else:
+                input_position = F.tuple_to_array(F.make_range(seq_length))
+                input_position = P.Tile()(input_position, (batch_size, 1))
+        position_embedding = self.position_embedding(input_position)
+        # Input features [bs, seq_length, embedding_size]
+        hidden_states = self.add(input_embedding, position_embedding)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = P.Cast()(hidden_states, mstype.float16)
+        attention_mask = self.expand_dims(attention_mask, 1)
+        return hidden_states, attention_mask, embedding_table
+
+
+class PanguAlpha_Model(nn.Cell):
+    """
+    The backbone of PanguAlpha network
+    Args:
+        config(PanguAlphaConfig): the config of network
+    Inputs:
+        input_ids: the tokenized inputs with datatype int32
+        input_mask: the mask indicating whether each position is a valid input
+        input_position: the position index of each token
+        attention_mask: the attention_mask attention for self-attention module
+        init_reset: whether reset saved key and value states
+        batch_valid_length: the valid input sequence length without padding
+    Returns:
+        output_state: Tensor, the output logit of backbone
+        embedding_table: Tensor, the embedding table for the vocabulary
+    """
+    def __init__(self, config):
+        super(PanguAlpha_Model, self).__init__()
+        self.embedding = Embedding(config)
+        self.blocks = nn.CellList()
         # Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
         fusion_group_num = 4
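The position handling in Embedding.construct has two regimes; the NumPy lines below (toy sizes, hypothetical variable names) only illustrate what the position indices look like in each case.

import numpy as np

batch_size, seq_length = 2, 6   # assumed toy sizes

# First graph / full prompt: positions are simply 0..seq_length-1, tiled per batch row.
input_position = np.tile(np.arange(seq_length), (batch_size, 1))
print(input_position)           # [[0 1 2 3 4 5], [0 1 2 3 4 5]]

# Incremental step: the input is a single token, so the position embedding is looked up
# at the current valid index instead (valid_index plays that role in the PR).
valid_index = np.array([3])     # the newest token sits at position 3
step_position = valid_index.reshape(1, 1)
print(step_position)            # [[3]]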
@@ -794,9 +1045,9 @@ class PanguAlpha_Model(nn.Cell):
         print("After setting the layer is:", num_layers, flush=True)

         for i in range(num_layers):
-            per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
+            per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
             # Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
             # in other words, in the backward process each block will be almost totally recomputed.
             if config.use_recompute:
                 per_block.recompute()
             self.blocks.append(per_block)
@@ -813,13 +1064,7 @@ class PanguAlpha_Model(nn.Cell):
         self.layernorm.beta.parallel_optimizer = False
         self.use_past = config.use_past
-        self.past = tuple([None] * config.num_layers)
-        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
-        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
         self.dtype = config.compute_dtype
-        self.dropout = Dropout(1 - config.dropout_rate)
-        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
-        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
-        self.eod_reset = config.eod_reset
         # If top_query_attention enabled, the input_position representing the position ids will be used as the index
         # for a query embedding table to obtain top query hidden states, together with the previous outputs of normal
         # self-attention layers, a new attention layer will be attached to the output of the model
@@ -853,33 +1098,18 @@ class PanguAlpha_Model(nn.Cell):
         self.use_top_query_attention = config.use_top_query_attention


-    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
+    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
+                  init_reset=True, batch_valid_length=None):
         """PanguAlpha model"""
-        if not self.use_past:
-            layer_past = self.past
+        # embedding for input_ids and the lower triangle like attention_mask matrix
+        hidden_states, attention_mask, embedding_table = self.embedding(input_ids, input_mask,
+                                                                        input_position, attention_mask,
+                                                                        batch_valid_length)

-        # Word embedding
-        input_embedding, embedding_table = self.word_embedding(input_ids)
-        # If eod_reset disabled, there will be only one input from the dataset, i.e., input_ids
-        # and the corresponding input_position and attention_mask will be derived from it.
-        if not self.eod_reset:
-            batch_size, seq_length = F.shape(input_ids)
-            input_position = F.tuple_to_array(F.make_range(seq_length))
-            input_position = P.Tile()(input_position, (batch_size, 1))
-            attention_mask = self.get_attention_mask(input_mask)
-        position_embedding = self.position_embedding(input_position)
-        # Input features [bs, seq_length, embedding_size]
-        hidden_states = self.add(input_embedding, position_embedding)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = P.Cast()(hidden_states, mstype.float16)
-        attention_mask = self.expand_dims(attention_mask, 1)
-
-        present_layer = ()
         # Loop through each self-attention layer
         for i in range(self.num_layers):
-            hidden_states, present = self.blocks[i](hidden_states,
-                                                    attention_mask, layer_past)
-            present_layer = present_layer + (present,)
+            hidden_states, _ = self.blocks[i](hidden_states,
+                                              attention_mask, init_reset, batch_valid_length)

         output_state = self.layernorm(hidden_states)
         output_state = F.cast(output_state, self.dtype)
@@ -887,10 +1117,9 @@ class PanguAlpha_Model(nn.Cell):
         # Top query attention layer
         if self.use_top_query_attention:
             top_query_hidden_states = self.top_query_embedding(input_position)
-            output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
-                                                         attention_mask, layer_past)
-            present_layer = present_layer + (present,)
-        return output_state, present_layer, embedding_table
+            output_state, _ = self.top_query_layer(output_state, top_query_hidden_states,
+                                                   attention_mask, init_reset, batch_valid_length)
+        return output_state, embedding_table

 class PanguAlpha_ModelPipeline(nn.Cell):
     """
@@ -922,7 +1151,7 @@ class PanguAlpha_ModelPipeline(nn.Cell):
                 self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
                 per_block = QueryLayer(config).set_comm_fusion(2)
             else:
-                per_block = Block(config, i + 1).set_comm_fusion(2)
+                per_block = Decoder(config, i + 1).set_comm_fusion(2)
             per_block.pipeline_stage = i * config.stage_num // config.num_layers
             per_block.recompute()
             self.blocks.append(per_block)
@@ -936,27 +1165,23 @@ class PanguAlpha_ModelPipeline(nn.Cell):
         self.layernorm.set_comm_fusion(2)
         self.layernorm.pipeline_stage = config.stage_num - 1
         self.use_past = config.use_past
-        self.past = tuple([None] * config.num_layers)
         self.dtype = config.compute_dtype
         self.num_layers = config.num_layers

-    def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
+    def construct(self, input_ids, input_mask, table, input_position, attention_mask,
+                  init_reset=True, batch_valid_length=None):
         """PanguAlpha model"""
-        if not self.use_past:
-            layer_past = self.past

         hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
         attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)

-        present_layer = ()
         for i in range(self.num_layers-1):
-            hidden_states, present = self.blocks[i](hidden_states,
-                                                    attention_mask, layer_past)
-            present_layer = present_layer + (present,)
+            hidden_states, _ = self.blocks[i](hidden_states,
+                                              attention_mask, init_reset, batch_valid_length)

         top_query_hidden_states = self.top_query_embedding(input_position)
-        hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
-                                                                attention_mask, layer_past)
-        present_layer = present_layer + (present,)
+        hidden_states, present_layer = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
+                                                                      attention_mask, init_reset, batch_valid_length)
         output_state = self.layernorm(hidden_states)
         output_state = F.cast(output_state, self.dtype)
         return output_state, present_layer
@@ -998,7 +1223,10 @@ class PanguAlpha(nn.Cell):
     Inputs:
         input_ids: the tokenized inputs
         input_mask: the mask indicating whether each position is a valid input
-        past: the previous feature map
+        input_position: the position index of each token
+        attention_mask: the attention_mask attention for self-attention module
+        init_reset: whether reset saved key and value states
+        batch_valid_length: the valid input sequence length without padding
     Returns:
         logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
     """
@@ -1009,9 +1237,10 @@ class PanguAlpha(nn.Cell):
         # Network head to get logits over vocabulary
         self.head = PanguAlpha_Head(config)

-    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
-        output_states, _, embedding_table = self.backbone(
-            input_ids, input_mask, input_position, attention_mask, past)
+    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
+                  init_reset=True, batch_valid_length=None):
+        output_states, embedding_table = self.backbone(
+            input_ids, input_mask, input_position, attention_mask, init_reset, batch_valid_length)
         logits = self.head(output_states, embedding_table)
         return logits

@@ -1039,9 +1268,10 @@ class PanguAlphaPipeline(nn.Cell):
         self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
         self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

-    def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
+    def construct(self, input_ids, input_mask, input_position, attention_mask,
+                  init_reset=True, batch_valid_length=None):
         output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
-                                         input_position, attention_mask, past)
+                                         input_position, attention_mask, init_reset, batch_valid_length)
         logits = self.head(output_states, self.embedding_table)
         return logits

@@ -1213,24 +1443,27 @@ class EvalNet(nn.Cell):
     Inputs:
         input_ids: the tokenized inputs
         current_index: the index of the current token
+        init_reset: whether reset saved states
     Returns:
         outputs: Tensor, corresponding output for different tasks
     """
-    def __init__(self, backbone, generate=False):
+    def __init__(self, backbone, generate=False, pad_token=6):
         super(EvalNet, self).__init__(auto_prefix=False)
         self.backbone = backbone
+        self.pad_token = pad_token
         self.argmax = P.Argmax()
         self.generate = generate
         self.topk = P.TopK(sorted=True).shard(((1, 1),))
-        self.gather = P.GatherV2().shard(((1, 1, 1), (1,)))
-        self.log_softmax = P.LogSoftmax().shard(((1, 1),))
+        self.gather = P.GatherV2().shard(((1, 1), (1,)))
+        self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),))

-    def construct(self, input_ids, current_index):
+    def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None):
         """evaluation net"""
-        input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
+        input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32)
         logits = self.backbone(input_ids, input_mask)
-        bs, seq_length = F.shape(input_ids)
-        logits = F.reshape(logits, (bs, seq_length, -1))
-        logits = self.gather(logits, current_index, 1)
+        index = current_index.view(1,)
+        logits = self.gather(logits, index, 0)
+        bs, _ = F.shape(input_ids)
+        logits = logits.view(bs, 1, -1)
         log_probs = self.log_softmax(logits)
         return log_probs
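For reference, a NumPy sketch of the reworked logit gathering in EvalNet.construct (illustrative shapes, batch size 1): the flattened (batch_size * seq_length, vocab_size) backbone output is indexed directly at the current token position and then viewed as (batch_size, 1, vocab_size), which matches what the old reshape-then-gather path produced.

import numpy as np

bs, seq_length, vocab_size = 1, 8, 16     # assumed toy sizes
flat_logits = np.random.randn(bs * seq_length, vocab_size)   # stand-in backbone output
current_index = np.array([4])             # flattened position of the last valid token

# Old path: reshape to (bs, seq_length, vocab) and gather along the sequence axis.
old = flat_logits.reshape(bs, seq_length, vocab_size)[:, current_index[0], :]

# New path: gather the row directly from the flattened logits, then view as (bs, 1, vocab).
new = flat_logits[current_index].reshape(bs, 1, vocab_size)

assert np.allclose(old, new[:, 0, :])
# log-softmax over the vocabulary gives the distribution used for sampling
log_probs = new - np.log(np.exp(new).sum(axis=-1, keepdims=True))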