frequency_penalty for pangu

This commit is contained in:
wangjun 2021-06-03 11:12:40 +08:00
parent a3b04a78af
commit 7418f5fda9
4 changed files with 96 additions and 18 deletions

View File

@ -134,12 +134,12 @@ def run_predict(args_opt):
start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
input_ids = np.array(start_sentence).reshape(1, -1)
# Call inference
output_ids = generate(model_predict, input_ids, config.seq_length, 9)
output_ids = generate(model_predict, input_ids, opt)
# Decode output ids to sentence
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
print('Output is:', output_samples, flush=True)
if __name__ == "__main__":
opt = get_args()
opt = get_args(True)
set_parse(opt)
run_predict(opt)

View File

@ -20,41 +20,87 @@ TopK for text generation
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
def generate(model, origin_inputs, seq_length, end_token=50256):
def generate(model, origin_inputs, config):
"""
TopK for text generation
Inputs:
model: the model for inferencing
origin_inputs: the original inputs based on which the model will continue writing
seq_length: seq_length for the model
end_token: end of sentence token id
config: inference configurations
Returns:
outputs: the ids for the generated text
"""
seq_length = seq_length
# Get configurations for inference
frequency_penalty = config.frequency_penalty
presence_penalty = config.presence_penalty
top_p = config.top_p
top_k_num = config.top_k_num
max_generate_length = config.max_generate_length
seq_length = config.seq_length
end_token = config.end_token
_, valid_length = origin_inputs.shape
# If target length exceeds seq_length, use seq_length instead
target_length = valid_length + max_generate_length
target_length = seq_length if target_length > seq_length else target_length
# A list of the frequency of each token
frequency_list = np.array([[0 for _ in range(config.vocab_size)]])
pad_length = seq_length - origin_inputs.shape[-1]
# Pad original inputs to seq_length
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
print("input_ids is ", input_ids)
# A single loop generates one token, loop until reaching target seq_length or generating eod token
while valid_length < seq_length:
while valid_length < target_length:
inputs = Tensor(input_ids, mstype.int32)
# Indicate the exact token position
current_index = valid_length - 1 if valid_lenght - 1 > 0 else 0
current_index = Tensor([current_index], mstype.int32)
# Call a single inference
probs, p_args = model.predict(inputs)
# Get the topk value and index for the current position from the padded outputs
probs = probs.asnumpy()[valid_length-1, :]
p_args = p_args.asnumpy()[valid_length-1, :]
log_probs = model.predict(inputs, current_index)
# Revise log_probs with the frequency and presence penalties to reduce repetition in the generated results
log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
# Convert the log_probs to probability
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
# If top_p is less than 1.0, use top_p sampling
if top_p < 1.0:
# Only consider the 5000 largest logits to reduce computation
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
cumsum_logits = P.CumSum()(sorted_logits, 1)
cumsum_logits = cumsum_logits.asnumpy()[0]
index = index.asnumpy()[0]
sorted_logits = sorted_logits.asnumpy()[0]
top_p_num = sum(cumsum_logits > top_p)
# In case the distribution is flat, the sum of the 5000 largest probabilities may not exceed top_p
if top_p_num == 0:
top_p_num = 5000
# Get the corresponding probs and indices
probs = sorted_logits[:top_p_num]
p_args = index[:top_p_num]
p = probs / sum(probs)
# if top_p is set to 1.0, use top_k sampling
else:
# Get the corresponding probs and indices
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
probs = probs.asnumpy()[0]
p_args = p_args.asnumpy()[0]
# Avoid rounding error
if sum(probs) == 0:
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
p = probs / sum(probs)
# Randomly select a token as the final output for this round
p = probs
p = p / sum(p)
target_index = np.random.choice(len(p), p=p)
# Stop judgment
if p_args[target_index] == end_token or valid_length == seq_length-1:
if p_args[target_index] == end_token or valid_length == target_length-1:
outputs = input_ids
break
# Modify input_ids with newly generated token

View File

@ -1047,6 +1047,7 @@ class EvalNet(nn.Cell):
generate: enable generate mode
Inputs:
input_ids: the tokenized inputs
current_index: the index of current token
Returns:
outputs: Tensor, corresponding output for different tasks
"""
@ -1056,10 +1057,13 @@ class EvalNet(nn.Cell):
self.argmax = P.Argmax()
self.generate = generate
self.topk = P.TopK(sorted=True).shard(((1, 1),))
self.gather = P.GatherV2().shard(((1, 1), (1,)))
self.log_softmax = P.LogSoftmax().shard(((1, 1),))
def construct(self, input_ids):
def construct(self, input_ids, current_index):
"""evaluation net"""
input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
logits = self.backbone(input_ids, input_mask)
value, index = self.topk(logits, 5)
return value, index
logits = self.gather(logits, current_index)
log_probs = self.log_softmax(logits)
return log_probs

View File

@ -188,6 +188,32 @@ class LearningRate(LearningRateSchedule):
lr = decay_lr
return lr * self.lr_scale
def add_inference_params(opt):
    """Register the inference (text generation) options on a parser.

    Args:
        opt: an object exposing ``add_argument(flag, type=..., default=...,
            help=...)``, e.g. an ``argparse.ArgumentParser``.

    Returns:
        None. ``opt`` is modified in place; every option listed in the table
        below is added with its type, default value and help text.
    """
    # One row per option: (flag, type, default, help). Keeping the options in
    # a table guarantees they are all registered through the same call.
    inference_options = (
        ("--frequency_penalty", float, 1.5,
         "coefficient for frequency_penalty"),
        ("--presence_penalty", float, 0.3,
         "coefficient for presence_penalty"),
        ("--max_generate_length", int, 500,
         "the maximum number of generated token"),
        ("--top_k_num", int, 3,
         "the number for top_k sampling"),
        ("--top_p", float, 1.0,
         "top_p sampling threshold, enabled if less than 1.0"),
        ("--end_token", int, 9,
         "the token id for <end of document>"),
    )
    for flag, arg_type, default_value, help_text in inference_options:
        opt.add_argument(flag,
                         type=arg_type,
                         default=default_value,
                         help=help_text)
def add_training_params(opt):
"""Add training params"""
@ -245,7 +271,7 @@ def add_training_params(opt):
default=2,
help="The sink size of the training")
def get_args():
def get_args(inference=False):
"""Build the PanguAlpha command-line parser and return the parsed arguments"""
parser = argparse.ArgumentParser(description="PanguAlpha training")
parser.add_argument('--device_id',
@ -301,6 +327,8 @@ def get_args():
help="The initialization type for parameters. Default fp32.")
add_training_params(parser)
if inference:
add_inference_params(parser)
args_opt = parser.parse_args()
return args_opt