forked from mindspore-Ecosystem/mindspore
frequency_penalty for pangu
This commit is contained in:
parent
a3b04a78af
commit
7418f5fda9
|
@ -134,12 +134,12 @@ def run_predict(args_opt):
|
|||
start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
|
||||
input_ids = np.array(start_sentence).reshape(1, -1)
|
||||
# Call inference
|
||||
output_ids = generate(model_predict, input_ids, config.seq_length, 9)
|
||||
output_ids = generate(model_predict, input_ids, opt)
|
||||
# Decode output ids to sentence
|
||||
output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
|
||||
print('Output is:', output_samples, flush=True)
|
||||
|
||||
if __name__ == "__main__":
    # Parse the command line exactly once, asking get_args to also register
    # the inference options. A preceding plain get_args() call would be
    # discarded immediately and would parse the command line without those
    # options being known to the parser.
    opt = get_args(True)
    set_parse(opt)
    run_predict(opt)
|
||||
|
|
|
@ -20,41 +20,87 @@ TopK for text generation
|
|||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
def generate(model, origin_inputs, seq_length, end_token=50256):
|
||||
def generate(model, origin_inputs, config):
|
||||
"""
|
||||
TopK for text generation
|
||||
|
||||
Inputs:
|
||||
model: the model for inferencing
|
||||
origin_inputs: the original inputs based on which the model will continue writing
|
||||
seq_length: seq_length for the model
|
||||
end_token: end of sentence token id
|
||||
config: inference configurations
|
||||
|
||||
Returns:
|
||||
outputs: the ids for the generated text
|
||||
"""
|
||||
seq_length = seq_length
|
||||
# Get configurations for inference
|
||||
frequency_penalty = config.frequency_penalty
|
||||
presence_penalty = config.presence_penalty
|
||||
top_p = config.top_p
|
||||
top_k_num = config.top_k_num
|
||||
max_generate_length = config.max_generate_length
|
||||
seq_length = config.seq_length
|
||||
end_token = config.end_token
|
||||
|
||||
_, valid_length = origin_inputs.shape
|
||||
# If target length exceeds seq_length, use seq_length instead
|
||||
target_length = valid_length + max_generate_length
|
||||
target_length = seq_length if target_length > seq_length else target_length
|
||||
|
||||
# A list of the frequency of each token
|
||||
frequency_list = np.array([[0 for _ in range(config.vocab_size)]])
|
||||
pad_length = seq_length - origin_inputs.shape[-1]
|
||||
# Pad original inputs to seq_length
|
||||
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
|
||||
print("input_ids is ", input_ids)
|
||||
|
||||
# A single loop generates one token, loop until reaching target seq_length or generating eod token
|
||||
while valid_length < seq_length:
|
||||
while valid_length < target_length:
|
||||
inputs = Tensor(input_ids, mstype.int32)
|
||||
# Indicate the exact token position
|
||||
current_index = valid_length - 1 if valid_lenght - 1 > 0 else 0
|
||||
current_index = Tensor([current_index], mstype.int32)
|
||||
# Call a single inference
|
||||
probs, p_args = model.predict(inputs)
|
||||
# Get the topk value and index for the current position from the padded outputs
|
||||
probs = probs.asnumpy()[valid_length-1, :]
|
||||
p_args = p_args.asnumpy()[valid_length-1, :]
|
||||
log_probs = model.predict(inputs, current_index)
|
||||
# Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results
|
||||
log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
|
||||
log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
|
||||
|
||||
# Convert the log_probs to probability
|
||||
logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
|
||||
|
||||
# If top_p is less than 1.0, use top_p sampling
|
||||
if top_p < 1.0:
|
||||
# Only consider the 5000 largest logits to reduce computation
|
||||
sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
|
||||
cumsum_logits = P.CumSum()(sorted_logits, 1)
|
||||
cumsum_logits = cumsum_logits.asnumpy()[0]
|
||||
index = index.asnumpy()[0]
|
||||
sorted_logits = sorted_logits.asnumpy()[0]
|
||||
top_p_num = sum(cumsum_logits > top_p)
|
||||
# In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
|
||||
if top_p_num == 0:
|
||||
top_p_num = 5000
|
||||
# Get the corresponding probs and indices
|
||||
probs = sorted_logits[:top_p_num]
|
||||
p_args = index[:top_p_num]
|
||||
p = probs / sum(probs)
|
||||
# if top_p is set to 1.0, use top_k sampling
|
||||
else:
|
||||
# Get the corresponding probs and indices
|
||||
probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
|
||||
probs = probs.asnumpy()[0]
|
||||
p_args = p_args.asnumpy()[0]
|
||||
# Avoid rounding error
|
||||
if sum(probs) == 0:
|
||||
probs = np.array([1 / top_k_num for _ in range(top_k_num)])
|
||||
p = probs / sum(probs)
|
||||
|
||||
# Random select a token as final output for this round
|
||||
p = probs
|
||||
p = p / sum(p)
|
||||
target_index = np.random.choice(len(p), p=p)
|
||||
# Stop judgment
|
||||
if p_args[target_index] == end_token or valid_length == seq_length-1:
|
||||
if p_args[target_index] == end_token or valid_length == target_length-1:
|
||||
outputs = input_ids
|
||||
break
|
||||
# Modify input_ids with newly generated token
|
||||
|
|
|
@ -1047,6 +1047,7 @@ class EvalNet(nn.Cell):
|
|||
generate: enable generate mode
|
||||
Inputs:
|
||||
input_ids: the tokenized inputs
|
||||
current_index: the index of current token
|
||||
Returns:
|
||||
outputs: Tensor, corresponding output for different tasks
|
||||
"""
|
||||
|
@ -1056,10 +1057,13 @@ class EvalNet(nn.Cell):
|
|||
self.argmax = P.Argmax()
|
||||
self.generate = generate
|
||||
self.topk = P.TopK(sorted=True).shard(((1, 1),))
|
||||
self.gather = P.GatherV2().shard(((1, 1), (1,)))
|
||||
self.log_softmax = P.LogSoftmax().shard(((1, 1),))
|
||||
|
||||
def construct(self, input_ids):
|
||||
def construct(self, input_ids, current_index):
|
||||
"""evaluation net"""
|
||||
input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
|
||||
logits = self.backbone(input_ids, input_mask)
|
||||
value, index = self.topk(logits, 5)
|
||||
return value, index
|
||||
logits = self.gather(logits, current_index)
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
|
|
@ -188,6 +188,32 @@ class LearningRate(LearningRateSchedule):
|
|||
lr = decay_lr
|
||||
return lr * self.lr_scale
|
||||
|
||||
def add_inference_params(opt):
    """Add inference params

    Registers the text-generation options on the given parser in place.

    Inputs:
        opt: object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``); it is mutated, nothing is returned.

    Options added (with their defaults):
        --frequency_penalty (float, 1.5): coefficient for frequency_penalty
        --presence_penalty (float, 0.3): coefficient for presence_penalty
        --max_generate_length (int, 500): maximum number of generated tokens
        --top_k_num (int, 3): the number for top_k sampling
        --top_p (float, 1.0): top_p sampling threshold, enabled if below 1.0
        --end_token (int, 9): token id for <end of document>
    """
    opt.add_argument("--frequency_penalty",
                     type=float,
                     default=1.5,
                     help="coefficient for frequency_penalty")
    opt.add_argument("--presence_penalty",
                     type=float,
                     default=0.3,
                     help="coefficient for presence_penalty")
    opt.add_argument("--max_generate_length",
                     type=int,
                     default=500,
                     help="the maximum number of generated token")
    opt.add_argument("--top_k_num",
                     type=int,
                     default=3,
                     help="the number for top_k sampling")
    opt.add_argument("--top_p",
                     type=float,
                     default=1.0,
                     help="top_p sampling threshold, enabled if less than 1.0")
    opt.add_argument("--end_token",
                     type=int,
                     default=9,
                     help="the token id for <end of document>")
|
||||
|
||||
def add_training_params(opt):
|
||||
"""Add training params"""
|
||||
|
@ -245,7 +271,7 @@ def add_training_params(opt):
|
|||
default=2,
|
||||
help="The sink size of the training")
|
||||
|
||||
def get_args():
|
||||
def get_args(inference=False):
|
||||
"""train function for PanguAlpha"""
|
||||
parser = argparse.ArgumentParser(description="PanguAlpha training")
|
||||
parser.add_argument('--device_id',
|
||||
|
@ -301,6 +327,8 @@ def get_args():
|
|||
help="The initialization type for parameters. Default fp32.")
|
||||
|
||||
add_training_params(parser)
|
||||
if inference:
|
||||
add_inference_params(parser)
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
return args_opt
|
||||
|
|
Loading…
Reference in New Issue