PanGu-Alpha inference: make text generation configurable (top-k / top-p
sampling, frequency and presence penalties, maximum generated length, end
token) and have EvalNet return the log-probabilities gathered at the current
token position instead of a fixed top-5.

Review corrections folded into this patch (each is a one-for-one line change,
so the hunk line counts are unchanged):
  * generate(): fix the `valid_lenght` typo, which raised NameError on the
    first loop iteration.
  * run_predict(): pass its own `args_opt` parameter to generate() instead of
    the module-level `opt`, which is only defined under `__main__`.
  * EvalNet.construct(): pass the axis operand (0) to GatherV2.
  * --end_token: restore the truncated help text (presumably
    "<end of document>" - TODO confirm against the tokenizer).

diff --git a/model_zoo/official/nlp/pangu_alpha/predict.py b/model_zoo/official/nlp/pangu_alpha/predict.py
index db325be8e2b..0d8ce6bd9d7 100644
--- a/model_zoo/official/nlp/pangu_alpha/predict.py
+++ b/model_zoo/official/nlp/pangu_alpha/predict.py
@@ -134,12 +134,12 @@ def run_predict(args_opt):
     start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
     input_ids = np.array(start_sentence).reshape(1, -1)
     # Call inference
-    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
+    output_ids = generate(model_predict, input_ids, args_opt)
     # Decode output ids to sentence
     output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
     print('Output is:', output_samples, flush=True)
 
 if __name__ == "__main__":
-    opt = get_args()
+    opt = get_args(True)
     set_parse(opt)
     run_predict(opt)
diff --git a/model_zoo/official/nlp/pangu_alpha/src/generate.py b/model_zoo/official/nlp/pangu_alpha/src/generate.py
index fc69156de40..a729dc9b23e 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/generate.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/generate.py
@@ -20,41 +20,87 @@ TopK for text generation
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore.ops import operations as P
 
-def generate(model, origin_inputs, seq_length, end_token=50256):
+def generate(model, origin_inputs, config):
     """
     TopK for text generation
 
     Inputs:
         model: the model for inferencing
         origin_inputs: the original inputs based on which the model will continue writing
-        seq_length: seq_length for the model
-        end_token: end of sentence token id
+        config: inference configurations
 
     Returns:
         outputs: the ids for the generated text
     """
-    seq_length = seq_length
+    # Get configurations for inference
+    frequency_penalty = config.frequency_penalty
+    presence_penalty = config.presence_penalty
+    top_p = config.top_p
+    top_k_num = config.top_k_num
+    max_generate_length = config.max_generate_length
+    seq_length = config.seq_length
+    end_token = config.end_token
+
     _, valid_length = origin_inputs.shape
+    # If target length exceeds seq_length, use seq_length instead
+    target_length = valid_length + max_generate_length
+    target_length = seq_length if target_length > seq_length else target_length
+
+    # A list of the frequency of each token
+    frequency_list = np.array([[0 for _ in range(config.vocab_size)]])
     pad_length = seq_length - origin_inputs.shape[-1]
     # Pad original inputs to seq_length
     input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
     print("input_ids is ", input_ids)
 
     # A single loop generates one token, loop until reaching target seq_length or generating eod token
-    while valid_length < seq_length:
+    while valid_length < target_length:
         inputs = Tensor(input_ids, mstype.int32)
+        # Indicate the exact token position
+        current_index = valid_length - 1 if valid_length - 1 > 0 else 0
+        current_index = Tensor([current_index], mstype.int32)
         # Call a single inference
-        probs, p_args = model.predict(inputs)
-        # Get the topk value and index for the current position from the padded outputs
-        probs = probs.asnumpy()[valid_length-1, :]
-        p_args = p_args.asnumpy()[valid_length-1, :]
+        log_probs = model.predict(inputs, current_index)
+        # Get the revised log_probs considering frequency and presence penalty to eliminate duplicate in generated results
+        log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
+        log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
+
+        # Convert the log_probs to probability
+        logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32))
+
+        # If top_p is less than 1.0, use top_p sampling
+        if top_p < 1.0:
+            # Only consider the 5000 largest logits to reduce computation
+            sorted_logits, index = P.TopK(sorted=True)(logits, 5000)
+            cumsum_logits = P.CumSum()(sorted_logits, 1)
+            cumsum_logits = cumsum_logits.asnumpy()[0]
+            index = index.asnumpy()[0]
+            sorted_logits = sorted_logits.asnumpy()[0]
+            top_p_num = sum(cumsum_logits > top_p)
+            # In case the probability is smooth, the sum of 5000 largest probabilities are not large enough
+            if top_p_num == 0:
+                top_p_num = 5000
+            # Get the corresponding probs and indices
+            probs = sorted_logits[:top_p_num]
+            p_args = index[:top_p_num]
+            p = probs / sum(probs)
+        # if top_p is set to 1.0, use top_k sampling
+        else:
+            # Get the corresponding probs and indices
+            probs, p_args = P.TopK(sorted=True)(logits, top_k_num)
+            probs = probs.asnumpy()[0]
+            p_args = p_args.asnumpy()[0]
+            # Avoid rounding error
+            if sum(probs) == 0:
+                probs = np.array([1 / top_k_num for _ in range(top_k_num)])
+            p = probs / sum(probs)
+
         # Random select a token as final output for this round
-        p = probs
-        p = p / sum(p)
         target_index = np.random.choice(len(p), p=p)
         # Stop judgment
-        if p_args[target_index] == end_token or valid_length == seq_length-1:
+        if p_args[target_index] == end_token or valid_length == target_length-1:
             outputs = input_ids
             break
         # Modify input_ids with newly generated token
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
index 69386c6b393..f7211cc4bdd 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -1047,6 +1047,7 @@ class EvalNet(nn.Cell):
         generate: enable generate mode
     Inputs:
         input_ids: the tokenized inpus
+        current_index: the index of current token
     Returns:
         outputs: Tensor, corresponding output for different tasks
     """
@@ -1056,10 +1057,13 @@ class EvalNet(nn.Cell):
         self.argmax = P.Argmax()
         self.generate = generate
         self.topk = P.TopK(sorted=True).shard(((1, 1),))
+        self.gather = P.GatherV2().shard(((1, 1), (1,)))
+        self.log_softmax = P.LogSoftmax().shard(((1, 1),))
 
-    def construct(self, input_ids):
+    def construct(self, input_ids, current_index):
         """evaluation net"""
         input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32)
         logits = self.backbone(input_ids, input_mask)
-        value, index = self.topk(logits, 5)
-        return value, index
+        logits = self.gather(logits, current_index, 0)
+        log_probs = self.log_softmax(logits)
+        return log_probs
diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py
index a19d75481f0..170cf288539 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/utils.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py
@@ -188,6 +188,32 @@ class LearningRate(LearningRateSchedule):
             lr = decay_lr
         return lr * self.lr_scale
 
+def add_inference_params(opt):
+    """Add inference params"""
+    opt.add_argument("--frequency_penalty",
+                     type=float,
+                     default=1.5,
+                     help="coefficient for frequency_penalty")
+    opt.add_argument("--presence_penalty",
+                     type=float,
+                     default=0.3,
+                     help="coefficient for presence_penalty")
+    opt.add_argument("--max_generate_length",
+                     type=int,
+                     default=500,
+                     help="the maximum number of generated token")
+    opt.add_argument("--top_k_num",
+                     type=int,
+                     default=3,
+                     help="the number for top_k sampling")
+    opt.add_argument("--top_p",
+                     type=float,
+                     default=1.0,
+                     help="top_p sampling threshold, enabled if less than 1.0")
+    opt.add_argument("--end_token",
+                     type=int,
+                     default=9,
+                     help="the token id for <end of document>")
 
 def add_training_params(opt):
     """Add training params"""
@@ -245,7 +271,7 @@ def add_training_params(opt):
                      default=2,
                      help="The sink size of the training")
 
-def get_args():
+def get_args(inference=False):
     """train function for PanguAlpha"""
     parser = argparse.ArgumentParser(description="PanguAlpha training")
     parser.add_argument('--device_id',
@@ -301,6 +327,8 @@ def get_args():
                         help="The initialization type for parameters. Default fp32.")
 
     add_training_params(parser)
+    if inference:
+        add_inference_params(parser)
     args_opt = parser.parse_args()
 
     return args_opt