# DAMO-ConvAI/pcll/utils.py

import logging
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
from dataset import PadBatchSeq, pad_seq
import json
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import torch.distributed as dist
import os, time, gc, pickle, argparse, math, re
import torch.nn as nn
import torch.nn.functional as F  # used by the sampling loop in get_answer below
import torch.utils.data as data
import torch.multiprocessing as mp
import copy
from generate import *

loggers = {}

def get_logger(filename, level=logging.INFO, print2screen=True):
    global loggers
    if os.path.exists(filename):
        os.remove(filename)
    if loggers.get(filename):
        return loggers.get(filename)
    else:
        logger = logging.getLogger(filename)
        logger.setLevel(level)
        fh = logging.FileHandler(filename, encoding='utf-8')
        fh.setLevel(level)
        ch = logging.StreamHandler()
        ch.setLevel(level)
        formatter = logging.Formatter('[%(asctime)s][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
        # formatter = logging.Formatter('[%(asctime)s][%(thread)d][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        logger.addHandler(fh)
        if print2screen:
            logger.addHandler(ch)
        loggers[filename] = logger
        return logger

def num_params(model):
    return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])

def init_para_frompretrained(m, pm, share_para=False):
    # m.wte.weight = pm.wte.weight
    # m.wpe.weight = pm.wpe.weight
    m.wte.weight = copy.deepcopy(pm.wte.weight)
    m.wpe.weight = copy.deepcopy(pm.wpe.weight)
    for i in range(min(len(m.h), len(pm.h))):
        m.h[i].ln_1.weight = pm.h[i].ln_1.weight if share_para else copy.deepcopy(pm.h[i].ln_1.weight)
        # print('ln_1',type(m.h[i].ln_1.weight),m.h[i].ln_1.weight,flush=True)
        m.h[i].ln_1.bias = pm.h[i].ln_1.bias if share_para else copy.deepcopy(pm.h[i].ln_1.bias)
        m.h[i].attn.c_attn.weight = pm.h[i].attn.c_attn.weight if share_para else copy.deepcopy(pm.h[i].attn.c_attn.weight)
        m.h[i].attn.c_attn.bias = pm.h[i].attn.c_attn.bias if share_para else copy.deepcopy(pm.h[i].attn.c_attn.bias)
        m.h[i].attn.c_proj.weight = pm.h[i].attn.c_proj.weight if share_para else copy.deepcopy(pm.h[i].attn.c_proj.weight)
        m.h[i].attn.c_proj.bias = pm.h[i].attn.c_proj.bias if share_para else copy.deepcopy(pm.h[i].attn.c_proj.bias)
        m.h[i].ln_2.weight = pm.h[i].ln_2.weight if share_para else copy.deepcopy(pm.h[i].ln_2.weight)
        m.h[i].ln_2.bias = pm.h[i].ln_2.bias if share_para else copy.deepcopy(pm.h[i].ln_2.bias)
        m.h[i].mlp.c_fc.weight = pm.h[i].mlp.c_fc.weight if share_para else copy.deepcopy(pm.h[i].mlp.c_fc.weight)
        m.h[i].mlp.c_fc.bias = pm.h[i].mlp.c_fc.bias if share_para else copy.deepcopy(pm.h[i].mlp.c_fc.bias)
        m.h[i].mlp.c_proj.weight = pm.h[i].mlp.c_proj.weight if share_para else copy.deepcopy(pm.h[i].mlp.c_proj.weight)
        m.h[i].mlp.c_proj.bias = pm.h[i].mlp.c_proj.bias if share_para else copy.deepcopy(pm.h[i].mlp.c_proj.bias)
        # if isdecoder:
        #     m.h_ori[i].ln_1.weight = pm.h[i].ln_1.weight if share_para else copy.copy(pm.h[i].ln_1.weight)
        #     m.h_ori[i].ln_1.bias = pm.h[i].ln_1.bias if share_para else copy.copy(pm.h[i].ln_1.bias)
        #     m.h_ori[i].attn.c_attn.weight = pm.h[i].attn.c_attn.weight if share_para else copy.copy(pm.h[i].attn.c_attn.weight)
        #     m.h_ori[i].attn.c_attn.bias = pm.h[i].attn.c_attn.bias if share_para else copy.copy(pm.h[i].attn.c_attn.bias)
        #     m.h_ori[i].attn.c_proj.weight = pm.h[i].attn.c_proj.weight if share_para else copy.copy(pm.h[i].attn.c_proj.weight)
        #     m.h_ori[i].attn.c_proj.bias = pm.h[i].attn.c_proj.bias if share_para else copy.copy(pm.h[i].attn.c_proj.bias)
        #     m.h_ori[i].ln_2.weight = pm.h[i].ln_2.weight if share_para else copy.copy(pm.h[i].ln_2.weight)
        #     m.h_ori[i].ln_2.bias = pm.h[i].ln_2.bias if share_para else copy.copy(pm.h[i].ln_2.bias)
        #     m.h_ori[i].mlp.c_fc.weight = pm.h[i].mlp.c_fc.weight if share_para else copy.copy(pm.h[i].mlp.c_fc.weight)
        #     m.h_ori[i].mlp.c_fc.bias = pm.h[i].mlp.c_fc.bias if share_para else copy.copy(pm.h[i].mlp.c_fc.bias)
        #     m.h_ori[i].mlp.c_proj.weight = pm.h[i].mlp.c_proj.weight if share_para else copy.copy(pm.h[i].mlp.c_proj.weight)
        #     m.h_ori[i].mlp.c_proj.bias = pm.h[i].mlp.c_proj.bias if share_para else copy.copy(pm.h[i].mlp.c_proj.bias)
    m.ln_f.weight = pm.ln_f.weight if share_para else copy.deepcopy(pm.ln_f.weight)
    # m.ln_f.weight = pm.ln_f.weight if share_para else copy.copy(pm.ln_f.weight)
    m.ln_f.bias = pm.ln_f.bias if share_para else copy.deepcopy(pm.ln_f.bias)
    # m.ln_f.bias = pm.ln_f.bias if share_para else copy.copy(pm.ln_f.bias)

def init_params(m, pm, share_para=False):
    m.wte.weight = pm.wte.weight
    m.wpe.weight = pm.wpe.weight
    # for i in range(min(len(m.h), len(pm.h))):
    #     m.h[i].ln_1.weight = pm.h[i].ln_1.weight if share_para else copy.copy(pm.h[i].ln_1.weight)
    #     m.h[i].ln_1.bias = pm.h[i].ln_1.bias if share_para else copy.copy(pm.h[i].ln_1.bias)
    #     m.h[i].attn.c_attn.weight = pm.h[i].attn.c_attn.weight if share_para else copy.copy(pm.h[i].attn.c_attn.weight)
    #     m.h[i].attn.c_attn.bias = pm.h[i].attn.c_attn.bias if share_para else copy.copy(pm.h[i].attn.c_attn.bias)
    #     m.h[i].attn.c_proj.weight = pm.h[i].attn.c_proj.weight if share_para else copy.copy(pm.h[i].attn.c_proj.weight)
    #     m.h[i].attn.c_proj.bias = pm.h[i].attn.c_proj.bias if share_para else copy.copy(pm.h[i].attn.c_proj.bias)
    #     m.h[i].ln_2.weight = pm.h[i].ln_2.weight if share_para else copy.copy(pm.h[i].ln_2.weight)
    #     m.h[i].ln_2.bias = pm.h[i].ln_2.bias if share_para else copy.copy(pm.h[i].ln_2.bias)
    #     m.h[i].mlp.c_fc.weight = pm.h[i].mlp.c_fc.weight if share_para else copy.copy(pm.h[i].mlp.c_fc.weight)
    #     m.h[i].mlp.c_fc.bias = pm.h[i].mlp.c_fc.bias if share_para else copy.copy(pm.h[i].mlp.c_fc.bias)
    #     m.h[i].mlp.c_proj.weight = pm.h[i].mlp.c_proj.weight if share_para else copy.copy(pm.h[i].mlp.c_proj.weight)
    #     m.h[i].mlp.c_proj.bias = pm.h[i].mlp.c_proj.bias if share_para else copy.copy(pm.h[i].mlp.c_proj.bias)
    #     if isdecoder:
    #         m.h_ori[i].ln_1.weight = pm.h[i].ln_1.weight if share_para else copy.copy(pm.h[i].ln_1.weight)
    #         m.h_ori[i].ln_1.bias = pm.h[i].ln_1.bias if share_para else copy.copy(pm.h[i].ln_1.bias)
    #         m.h_ori[i].attn.c_attn.weight = pm.h[i].attn.c_attn.weight if share_para else copy.copy(pm.h[i].attn.c_attn.weight)
    #         m.h_ori[i].attn.c_attn.bias = pm.h[i].attn.c_attn.bias if share_para else copy.copy(pm.h[i].attn.c_attn.bias)
    #         m.h_ori[i].attn.c_proj.weight = pm.h[i].attn.c_proj.weight if share_para else copy.copy(pm.h[i].attn.c_proj.weight)
    #         m.h_ori[i].attn.c_proj.bias = pm.h[i].attn.c_proj.bias if share_para else copy.copy(pm.h[i].attn.c_proj.bias)
    #         m.h_ori[i].ln_2.weight = pm.h[i].ln_2.weight if share_para else copy.copy(pm.h[i].ln_2.weight)
    #         m.h_ori[i].ln_2.bias = pm.h[i].ln_2.bias if share_para else copy.copy(pm.h[i].ln_2.bias)
    #         m.h_ori[i].mlp.c_fc.weight = pm.h[i].mlp.c_fc.weight if share_para else copy.copy(pm.h[i].mlp.c_fc.weight)
    #         m.h_ori[i].mlp.c_fc.bias = pm.h[i].mlp.c_fc.bias if share_para else copy.copy(pm.h[i].mlp.c_fc.bias)
    #         m.h_ori[i].mlp.c_proj.weight = pm.h[i].mlp.c_proj.weight if share_para else copy.copy(pm.h[i].mlp.c_proj.weight)
    #         m.h_ori[i].mlp.c_proj.bias = pm.h[i].mlp.c_proj.bias if share_para else copy.copy(pm.h[i].mlp.c_proj.bias)
    m.ln_f.weight = pm.ln_f.weight if share_para else copy.copy(pm.ln_f.weight)
    m.ln_f.bias = pm.ln_f.bias if share_para else copy.copy(pm.ln_f.bias)

def switch_schedule(schedule, mult, switch):
    """ Apply LR multiplier before iteration "switch" """
    def f(e):
        s = schedule(e)
        if e < switch:
            return s * mult
        return s
    return f


def linear_schedule(args):
    def f(e):
        if e <= args.warmup:
            return e / args.warmup
        return max((e - args.iterations) / (args.warmup - args.iterations), 0)
    return f
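# Usage sketch (not from the original file; `optimizer` and the numbers are
# placeholders): the returned functions are LR multipliers meant for
# torch.optim.lr_scheduler.LambdaLR, e.g.
#   sched_fn = switch_schedule(linear_schedule(args), mult=0.1, switch=1000)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, sched_fn)
# where args.warmup and args.iterations are the warmup and total step counts.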

def get_mask(x_len):
    mask = torch.arange(max(x_len), device=x_len.device)[None, :] < x_len[:, None]  # [bs, max_len]
    return mask.bool()


def get_reverse_mask(x_len):
    mask = torch.arange(max(x_len) - 1, -1, -1, device=x_len.device)[None, :] < x_len[:, None]  # [bs, max_len]
    return mask.bool()
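# Example (illustrative): for x_len = torch.tensor([3, 1]),
#   get_mask(x_len)         -> [[True, True, True], [True, False, False]]  (right-padded batches)
#   get_reverse_mask(x_len) -> [[True, True, True], [False, False, True]]  (left-padded batches)
# which is why get_reverse_mask is paired with the left-padded generation inputs below.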

def compare_tokens(x, y, eos_id):
    if eos_id in x:
        x = x[:x.index(eos_id)]
    if eos_id in y:
        y = y[:y.index(eos_id)]
    return x == y


def pad_tensor(tensor, length, pad_token=0):
    return torch.cat([tensor, tensor.new(tensor.size(0), length - tensor.size()[1]).fill_(pad_token)], dim=1)


def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed_all(seed)
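# Note (added): assigning PYTHONHASHSEED here only affects subprocesses started
# afterwards, not hash randomization of the already-running interpreter; the
# torch/numpy/random calls above are what make this process reproducible.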

def communicate_tensor(tensor_list, pad_token=0):
    '''
    collect tensors from all processes
    '''
    if len(tensor_list) == 0:
        return None
    device = tensor_list[0].device
    max_len = torch.tensor(max([i.shape[1] for i in tensor_list]), dtype=torch.int64, device=device)
    if dist.is_initialized():  # Obtain the max_len of the second dim of each tensor
        dist.all_reduce(max_len, op=dist.ReduceOp.MAX)
    # Pad tensors to the max_len
    tensor = torch.cat([pad_tensor(i, max_len, pad_token) for i in tensor_list], dim=0)
    tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
    max_tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
    if dist.is_initialized():
        dist.all_reduce(max_tensor_bs, op=dist.ReduceOp.MAX)  # Obtain the max batch size across processes
        if max_tensor_bs != tensor_bs:
            tensor = torch.cat([tensor, tensor.new(max_tensor_bs - tensor_bs, tensor.shape[1]).fill_(pad_token)], dim=0)
        # Gather padded tensors and the bs of each tensor
        tensor_list = [torch.ones_like(tensor).fill_(pad_token) for _ in range(dist.get_world_size())]
        tensor_bs_list = [torch.ones_like(tensor_bs).fill_(pad_token) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=tensor_list, tensor=tensor.contiguous())
        dist.all_gather(tensor_list=tensor_bs_list, tensor=tensor_bs)
        # Cut the padded batch
        for i in range(dist.get_world_size()):
            tensor_list[i] = tensor_list[i][:tensor_bs_list[i]]
        tensor = torch.cat(tensor_list, dim=0)
    return tensor
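# Usage sketch (not from the original file): in a DDP evaluation loop each rank can
# append its per-batch prediction tensors of shape [bs, seq_len] to a list and call
#   all_preds = communicate_tensor(pred_list, pad_token=tokz.eos_token_id)
# to obtain every rank's predictions on each process; without torch.distributed
# initialized, the call simply pads and concatenates the local list.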

class MetricsRecorder(object):
    def __init__(self, device, *metric_names):
        self.metric_names = list(metric_names)
        self.device = device
        self.metrics = {}
        for metric_name in metric_names:
            self.metrics[metric_name] = torch.tensor(0, dtype=torch.float64, device=self.device)

    def metric_update(self, batch_no, metric_values_dict):
        for k, v in metric_values_dict.items():
            self.metrics[k] = (self.metrics[k] * batch_no + v) / (batch_no + 1)

    def add_to_writer(self, writer, step, group):
        for n in self.metric_names:
            m = self.metrics[n].item()
            writer.add_scalar('%s/%s' % (group, n), m, step)

    def write_to_logger(self, logger, epoch, step):
        log_str = 'epoch {:>3}, step {}'.format(epoch, step)
        for n in self.metric_names:
            m = self.metrics[n].item()
            log_str += ', %s %g' % (n, m)
        logger.info(log_str)

    def items(self):
        return self.metrics.items()

    def all_reduce(self):
        for n in self.metric_names:
            torch.distributed.all_reduce(self.metrics[n], op=torch.distributed.ReduceOp.SUM)
            self.metrics[n] /= torch.distributed.get_world_size()

    def __getitem__(self, k):
        return self.metrics[k]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def __repr__(self):
        return self.metrics.__repr__()

    def __str__(self):
        return self.metrics.__str__()

    def keys(self):
        return self.metrics.keys()
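# Usage sketch (not from the original file):
#   recorder = MetricsRecorder(device, 'loss', 'acc')
#   for step, batch in enumerate(loader):
#       ...
#       recorder.metric_update(step, {'loss': loss.item(), 'acc': acc})
#   recorder.all_reduce()   # only when torch.distributed is initialized
#   recorder.write_to_logger(logger, epoch, step)
# metric_update keeps a running mean over batches 0..step for every registered metric.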

def cut_eos(seq, eos_id):
    if eos_id not in seq:
        return seq
    return seq[:seq.index(eos_id)]

def infer_model_pred(model, tokz, dataset, outfile, batch_size=30):
    max_ans_len = dataset.max_ans_len + 1
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=3, pin_memory=True, collate_fn=PadBatchSeq(0))
    device = model.device
    with open(outfile, 'w', encoding='utf-8') as f:
        with torch.no_grad():
            model.eval()
            for i, data in enumerate(data_loader):
                bs = data['context_id'].shape[0]
                context = data['context_id'].to(device, non_blocking=True)
                context_lens = data['context_lens'].to(device, non_blocking=True)
                mask = get_reverse_mask(context_lens)
                output_sequence = model.generate(
                    input_ids=context, attention_mask=mask, do_sample=False, eos_token_id=tokz.eos_token_id,
                    pad_token_id=tokz.eos_token_id, max_length=context.shape[1] + max_ans_len, early_stopping=True)
                cls_res = output_sequence[:, context.shape[1]:].tolist()
                ans = data['ans_id'].tolist()
                for j in range(bs):  # renamed from `i` to avoid shadowing the batch index above
                    res = {}
                    res['context'] = tokz.decode(context[j][-context_lens[j]:])
                    res['ans_gold'] = tokz.decode(ans[j][:data['ans_lens'][j] - 1])
                    res['ans_pred'] = tokz.decode(cut_eos(cls_res[j], tokz.eos_token_id))
                    print(json.dumps(res), file=f)
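# Output format note (added): each line of `outfile` is a JSON object of the form
#   {"context": "...", "ans_gold": "...", "ans_pred": "..."}
# which is exactly what cal_metrics_from_pred_files below reads back to compute accuracy.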

def cal_metrics_from_pred_files(res_file):
    with open(res_file, 'r', encoding='utf-8') as f:
        res = [json.loads(i) for i in f.readlines()]
    y_true = [i['ans_gold'] for i in res]
    y_pred = [i['ans_pred'] for i in res]
    return {
        "accuracy": accuracy_score(y_true, y_pred),
    }

def slot_f1_score(pred_slots, true_slots):
    '''
    pred_slots, true_slots are like [['from_location:10-11', 'leaving_date:12-13']]
    '''
    slot_types = set([slot.split(":")[0] for row in true_slots for slot in row])
    slot_type_f1_scores = []
    for slot_type in slot_types:
        predictions_for_slot = [[p for p in prediction if slot_type in p] for prediction in pred_slots]  # e.g. [['from_location'], [], [], ['from_location']]
        labels_for_slot = [[l for l in label if slot_type in l] for label in true_slots]
        proposal_made = [len(p) > 0 for p in predictions_for_slot]
        has_label = [len(l) > 0 for l in labels_for_slot]
        prediction_correct = [prediction == label for prediction, label in zip(predictions_for_slot, labels_for_slot)]
        true_positives = sum([
            int(proposed and correct)
            for proposed, correct in zip(proposal_made, prediction_correct)])
        num_predicted = sum([int(proposed) for proposed in proposal_made])
        num_to_recall = sum([int(hl) for hl in has_label])
        precision = true_positives / (1e-5 + num_predicted)
        recall = true_positives / (1e-5 + num_to_recall)
        f1_score = 2 * precision * recall / (1e-5 + precision + recall)
        slot_type_f1_scores.append(f1_score)
    return np.mean(slot_type_f1_scores)
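# Worked example (illustrative): with
#   true_slots = [['from_location:10-11', 'leaving_date:12-13']]
#   pred_slots = [['from_location:10-11']]
# 'from_location' scores precision = recall ~ 1 (F1 ~ 1) while 'leaving_date' scores 0,
# so the returned macro average over slot types is roughly 0.5; the 1e-5 terms only
# guard against division by zero.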

def get_answer(tokz, model, example, example_lens, max_ans_len, sampling=True, temperature=0.95, top_k=100, top_p=0.95):
    model.eval()
    device = 'cuda'
    # context = example['context_id'].to(device, non_blocking=True)
    # context_lens = example['context_lens'].to(device, non_blocking=True)
    # print('context',context.size(),flush=True)
    mask = get_reverse_mask(example_lens)
    eos_token = tokz.eos_token_id
    # pad_token = tokz.pad_token
    pad_token = tokz.eos_token_id
    # only gpt2
    output_seq = model.decoder.generate(input_ids=example, attention_mask=mask, do_sample=False, eos_token_id=eos_token, pad_token_id=pad_token,
                                        max_length=example.shape[1] + max_ans_len, early_stopping=True)
    return output_seq[:, example.shape[1]:]

    # NOTE: everything below this return is unreachable; it is an older
    # sampling-based decoding loop kept for reference.
    # mask = example['context_mask'].to(device, non_blocking=True)
    # print('eval input:',example[0],flush=True)
    # print('eval mask:',mask[0],flush=True)
    logits, mem = model.transformer(input_ids=example, attention_mask=mask)
    batch_size = logits.size()[0]
    # print('batch size',batch_size,flush=True)
    # print('context',context.size(),flush=True)
    # prev = torch.tensor([[eos_token]] * batch_size, dtype=torch.long, device=device)
    prev = example[..., -1].view(batch_size, -1)
    output = prev
    probability = torch.tensor([], dtype=torch.float, device=device)
    if_end = torch.tensor([False] * batch_size, dtype=torch.bool, device=device)
    # logits = model.transformer.lm_head(logits)
    # print('prev is padding?',prev,flush=True)
    for i in range(0, max_ans_len):
        one_col = torch.tensor([[True]] * batch_size, dtype=torch.bool, device=device)
        mask = torch.cat((mask, one_col), dim=1)
        # print('mask',mask.size(),flush=True)
        # logits, mem = model.transformer(input_ids=prev, past=mem)  # mem is a tuple, doesn't have a size
        logits, mem = model.transformer(input_ids=prev, past=mem, attention_mask=mask)
        # print('logits',logits.size(),flush=True)
        logits = model.lm_head(logits)
        # print('mem',mem.size(),'mask',mask.size(),'prev',prev.size(),flush=True)
        logits = logits[:, -1, :] / temperature
        logits = top_k_top_p_filtering(logits, top_k, top_p)
        probs = F.softmax(logits, dim=-1)
        if sampling:
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            _, next_token = torch.topk(probs, k=1, dim=-1)
        probability = torch.cat((probability, probs.gather(1, next_token)), dim=1)
        output = torch.cat((output, next_token), dim=1)
        prev = next_token
        if_end[next_token.view(-1).eq(eos_token)] = True
        if if_end.all():
            break
    # print('output',output.size(),flush=True)
    return output[:, 1:]

def textid_decode(text, eos, tokz):
    if eos in text:
        text = text[:text.index(eos)]
    text = tokz.decode(text).strip()
    return text

def padding_convert(text_list, eos):
    tt_list = []
    for text in text_list:
        if eos in text:
            eos_indexs = [i for i, x in enumerate(text) if x == eos]
            # print('eos index', eos_indexs, flush=True)
            if len(eos_indexs) > 1:
                text = text[:eos_indexs[1]]
        tt_list.append(text)
    tt_lens = [len(i) for i in tt_list]
    tt_lens = torch.tensor(tt_lens, dtype=torch.long).to('cuda')
    tt_pad = torch.tensor([pad_seq(i, eos, max(tt_lens), pad_left=True) for i in tt_list], dtype=torch.long).to('cuda')
    return tt_pad, tt_lens
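# Behavior note (added): each generated id list is truncated at its second eos token
# (prompt eos + answer eos) and then left-padded with eos to a common length, so the
# batch can be fed back through get_answer with a get_reverse_mask-style mask.
# pad_seq(..., pad_left=True) from dataset.py is assumed to prepend the pad token.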

def gen_pseudo_data(model, task, dataset, max_output_len=90, batch_size=30, target_count=100, output_file=None,
                    top_k=100, top_p=0.95, temperature=1, only_decoder=False):
    device = 'cuda'
    prompt_id = [dataset.tokz.bos_token_id] + dataset.pseudo_data_prompt_id + [dataset.tokz.eos_token_id]
    prompt_id = torch.LongTensor([prompt_id for _ in range(batch_size)]).to(device)
    # prompt length is 15 for the ATIS intent detection dataset.
    prompt_mask, prompt_lens = None, None
    ans_prompt_id = dataset.pseudo_ans_prompt_id
    ans_prompt_id = torch.LongTensor([ans_prompt_id for _ in range(batch_size)]).to(device)
    res = []
    utter_set = set()
    eos_token = dataset.tokz.eos_token_id

    if output_file is not None:
        if os.path.exists(output_file):
            os.remove(output_file)
        if not os.path.isdir(os.path.dirname(output_file)):
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # output_file is assumed to always be provided by the caller.
    genefile = open(output_file, 'w', encoding='utf8')

    while len(res) < target_count:
        with torch.no_grad():
            model.eval()
            output_seq = sample_sequence(model, dataset.tokz, length=max_output_len, batch_size=batch_size,
                                         x_tokens=prompt_id, x_mask=prompt_mask, x_lens=prompt_lens, temperature=temperature,
                                         top_k=top_k, top_p=top_p, sampling=True, only_decoder=only_decoder)
            output_list = output_seq.tolist()
            # print(prompt_id.size(), ans_prompt_id.size(), output_seq.size(), flush=True)
            # * prompt (w/o eos) + generated utterance + answer prompt as the pseudo input for the LM.
            # lm_input_list = torch.cat((prompt_id[..., :-1], output_seq, ans_prompt_id), dim=1).tolist()
            # Change right padding to left padding.
            # lm_input, lm_input_lens = padding_convert(lm_input_list, eos_token)
            # if not only_decoder:
            #     label_batch = get_answer(eos_token, model, lm_input, lm_input_lens, max_ans_len=10).tolist()
            for i in range(batch_size):
                # print('LM INPUT::', dataset.tokz.decode(lm_input[i]), flush=True)
                # print('UTTER BEFORE::', dataset.tokz.decode(output_list[i]), flush=True)
                utter, label = None, ''  # initialized here to avoid a NameError when no answer can be parsed
                output = textid_decode(output_list[i], eos_token, dataset.tokz)
                if 'Answer:' in output:
                    regex = re.compile('(.+) \"\? Answer:')
                    utter1 = regex.search(output)
                    if utter1 is not None:
                        utter = utter1.group(1)
                        utter_id = dataset.tokz.encode(utter)
                    else:
                        utter = output
                        utter_id = output_list[i]
                # * Get labels.
                if not only_decoder:
                    lm_input = [torch.cat((prompt_id[i, :-1], output_seq[i, :], ans_prompt_id[i, :])).tolist()]
                    lm_input, lm_input_lens = padding_convert(lm_input, eos_token)
                    label_id = get_answer(dataset.tokz, model, lm_input, lm_input_lens, max_ans_len=10).tolist()[0]  # ? Not finished.
                    # label_id = label_batch[i]
                    label = textid_decode(label_id, eos_token, dataset.tokz)
                elif 'Answer:' in output:
                    label = re.findall('(?<=Answer: ).*$', output)[0]
                if utter is not None and label != '':
                    print('UTTER::', utter, '====>> LABEL::', label, flush=True)
                    if utter not in utter_set:
                        utter_set.add(utter)  # avoid duplicate utterances
                        res.append([utter, label])
                        print(json.dumps({'Utterence': utter, 'Label': label}, ensure_ascii=False), file=genefile, flush=True)
    genefile.close()
    return res[:target_count]  # only return the first target_count pseudo samples

def infer_batch_pseudo_data(model, dataset, max_output_len=90, batch_size=30,
                            temperature=1.0, top_k=100, top_p=0.95):
    # NOTE: the original call referenced x_mask, x_tokens, temperature, top_k, top_p
    # and device without defining them in this scope; the values below are assumptions
    # that mirror gen_pseudo_data above.
    device = model.device
    prompt_id = [dataset.tokz.bos_token_id] + dataset.pseudo_data_prompt_id
    max_output_len += len(prompt_id)
    prompt_id = torch.LongTensor(
        [prompt_id for _ in range(batch_size)]).to(model.device)
    x_tokens, x_mask = prompt_id, None
    with torch.no_grad():
        model.eval()
        # output_seq = model.generate(input_ids=prompt_id, do_sample=True, eos_token_id=dataset.tokz.eos_token_id,
        #                             pad_token_id=dataset.tokz.eos_token_id, max_length=max_output_len, early_stopping=True)
        output_seq, _ = sample_sequence(model, dataset.tokz, length=max_output_len, batch_size=batch_size,
                                        x_mask=x_mask, x_tokens=x_tokens, temperature=temperature,
                                        top_k=top_k, top_p=top_p, eos_token=dataset.tokz.eos_token_id, device=device)
        output_seq = output_seq.tolist()
    res = []
    for i in range(batch_size):
        output = output_seq[i][1:]
        if dataset.tokz.eos_token_id in output:
            output = output[:output.index(dataset.tokz.eos_token_id)]
        output = dataset.tokz.decode(output)
        output = dataset.parse_pseudo_data(output)
        if output is not None:
            res.append(output)
    return res

def strip_list(seq, eos_id):
    l, r = 0, len(seq) - 1
    for i in range(len(seq)):
        if seq[i] != eos_id:
            break
        l = i
    for i in range(len(seq) - 1, -1, -1):
        if seq[i] != eos_id:
            break
        r = i
    return seq[l + 1:r]
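# Behavior note (added): strip_list assumes the sequence is padded with eos_id on
# both ends, e.g. strip_list([eos, eos, 5, 6, eos], eos) returns [5, 6]. If a
# sequence does not start or end with eos_id, its first or last real token is
# dropped as well, so callers should only pass padded sequences.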