from mycvae.utils import *
from mycvae.model import *
import pickle
import os
import gc  # used by gc.collect() after loading model weights
import math
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
import numpy as np
import argparse
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from tqdm import tqdm
from tqdm import trange
import importlib
# import logging
import copy
# from data.util import *
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from rouge import Rouge
from dataset import get_datasets
from torch.utils.data import DataLoader
from dataset import PadBatchSeq, TASK2INFO


# devices = '2'
# os.environ["CUDA_VISIBLE_DEVICES"] = devices


def top_k_top_p_filtering(logits, top_k=100, top_p=0.95, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution of shape (batch size, vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors back to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits

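# A minimal, self-contained sketch of how top_k_top_p_filtering is meant to be used
# (this helper is never called by the script itself; the shapes and the random logits
# below are illustrative only, mirroring the per-step decoding logic in sample_sequence).
def _demo_top_k_top_p_sampling(top_k=100, top_p=0.95):
    logits = torch.randn(2, 50257)                        # fake [batch, vocab] logits
    filtered = top_k_top_p_filtering(logits.clone(), top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered, dim=-1)                   # filtered entries get ~0 probability
    next_token = torch.multinomial(probs, num_samples=1)  # [batch, 1] sampled token ids
    return next_token
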
def repeat_score(text, ngram=[3, 4, 5, 6]):
    """Heuristic repetition score for a list of words: higher means more repeated n-grams."""
    ngram_list = []
    for ng in ngram:
        ngram_list.append([text[idx:idx + ng] for idx in range(len(text) - ng - 1)])

    max_occurs = []
    for ngrams in ngram_list:
        count_result = Counter([' '.join(n) for n in ngrams])
        try:
            max_occurs.append(
                max(count_result.values())
            )
        except ValueError:  # no n-grams of this size (text too short)
            pass

    scores = [max_oc / ((len(text) / ngram[idx]) + ngram[idx]) for idx, max_oc in enumerate(max_occurs)]
    return max(scores) if len(scores) >= 1 else 1.0

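# Sketch of the (currently commented-out) repetition filter used further below in
# run_model: generations whose score exceeds roughly 0.075 were treated as degenerate
# repetitions. The threshold comes from that disabled code, not from repeat_score itself.
def _demo_repeat_filter(text, threshold=0.075):
    rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
    return rep_score > threshold  # True -> likely repetitive, would be skipped
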
def get_reverse_mask(x_len):
    mask = torch.arange(max(x_len)-1, -1, -1, device=x_len.device)[None, :] < x_len[:, None]  # [bs, max_len]
    return mask.bool()

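# Small illustration of get_reverse_mask (not executed by the script): it builds a
# right-aligned padding mask, i.e. the *last* x_len[i] positions of each row are True.
# For example, get_reverse_mask(torch.tensor([2, 4])) gives
#     [[False, False,  True,  True],
#      [ True,  True,  True,  True]]
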
def sample_sequence(model, tokenizer, length, batch_size=None, x_mask=None, x_tokens=None, x_lens=None,
                    temperature=1, top_k=100, top_p=0.95, sampling=True, only_decoder=False):
    """Autoregressively generate up to `length` tokens conditioned on the prompt `x_tokens`
    (and, unless `only_decoder` is set, on a latent z sampled from the prior encoder)."""
    device = 'cuda'  # generation assumes a CUDA device
    x_tokens = x_tokens.to(device)

    if x_mask is not None: x_mask = x_mask.to(device)
    if x_lens is not None:
        x_reverse_mask = get_reverse_mask(x_lens)
    else:
        x_reverse_mask = None
    eos_token = tokenizer.eos_token_id

    # mem = None
    # prev = torch.tensor([[eos_token]] * batch_size, dtype=torch.long, device=device)
    # prev = prev.view(batch_size, -1)

    with torch.no_grad():
        if not only_decoder:
            prior_mean, prior_logvar = model.encoder_prior(input_ids=x_tokens, attention_mask=x_mask)[:2]
            latent_mean, latent_logvar = prior_mean, prior_logvar
            z = model.reparameterize(latent_mean, latent_logvar)
            assert not torch.isnan(z).any(), 'sampling got NaN z'
            add_attn = True
        else:
            z = None
            add_attn = False

        # z: [bs, 768]
        # _, mem = model.transformer(input_ids=x_tokens[:, :-1], past=None, attention_mask=x_reverse_mask, representations=z)  # x_tokens--prompt
        _, mem = model.transformer(input_ids=x_tokens[:, :-1], past=None, representations=z, add_attn=add_attn)  # x_tokens--prompt

        prev = x_tokens[..., -2].view(batch_size, -1)
        # print('prev', prev, flush=True)
        output = prev
        probability = torch.tensor([], dtype=torch.float, device=device)
        if_end = torch.tensor([False] * batch_size, dtype=torch.bool, device=device)
        # print('z size', z.size(), flush=True)

        for i in range(length):  # trange
            # print('prev', prev.size(), flush=True)
            # one_col = torch.tensor([[True]] * batch_size, dtype=torch.bool, device=device)
            # if x_reverse_mask is not None:
            #     x_reverse_mask = torch.cat((x_reverse_mask, one_col), dim=1)

            logits, mem = model.transformer(input_ids=prev, past=mem, representations=z, add_attn=add_attn)
            # logits, mem = model.transformer(input_ids=prev, past=mem, attention_mask=x_reverse_mask, representations=z)

            logits = model.lm_head(logits)

            logits = logits[:, -1, :] / temperature
            logits = top_k_top_p_filtering(logits, top_k, top_p)
            probs = F.softmax(logits, dim=-1)

            if sampling:
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                _, next_token = torch.topk(probs, k=1, dim=-1)

            probability = torch.cat((probability, probs.gather(1, next_token)), dim=1)
            output = torch.cat((output, next_token), dim=1)
            prev = next_token

            # early stopping once every sequence in the batch has emitted eos at least once
            if_end[next_token.view(-1).eq(eos_token)] = True
            if if_end.all(): break

    return output[:, 1:]
    # return output

def decode_text(text, eos, tokenizer):
    # drop everything up to and including the first eos id, then truncate at the next eos (if any)
    text = text[text.index(eos) + 1:]
    if eos in text:
        idx = text.index(eos)
        text = text[:idx]
    text = tokenizer.decode(text).strip()

    return text

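# Minimal sketch of how sample_sequence and decode_text are combined in run_model below.
# This helper is not called anywhere; `data` is assumed to be one batch from the test
# DataLoader and the generation length of 64 is an arbitrary illustration.
def _demo_generate(model, tokenizer, data, device='cuda'):
    x_tokens = data['prompt_id']
    x_mask = data['prompt_mask']
    x_lens = torch.tensor([len(t) - 1 for t in x_tokens], dtype=torch.long).to(device)
    out = sample_sequence(model=model, tokenizer=tokenizer, length=64,
                          batch_size=x_tokens.size(0), x_mask=x_mask,
                          x_tokens=x_tokens, x_lens=x_lens,
                          temperature=0.95, top_k=100, top_p=0.95).tolist()
    eos = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    return [decode_text(ids, eos, tokenizer) for ids in out]
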
def run_model():
    parser = argparse.ArgumentParser()
    # CHANGED
    parser.add_argument('--experiment', type=str)
    parser.add_argument("--data_dir", default="./PLL_DATA/",
                        type=str, help="The path to train/dev/test data files.")  # Default parameters are set based on single GPU training
    parser.add_argument('--tasks', nargs='+', default=['banking'])
    parser.add_argument('--model_path', type=str, help='pretrained model path to local checkpoint')
    # NON-CHANGED
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=0.95)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--top_k', type=int, default=100)
    parser.add_argument('--output_dir', type=str, default='out')

    parser.add_argument('--data_type', type=str, default='t1', choices=['t' + str(i) for i in range(9)], help="t: type")
    parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae', 'ae_vae_fusion'])

    # use GPU
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--no_gpu', action="store_true")

    parser.add_argument('--add_input', action="store_true")
    parser.add_argument('--add_attn', action="store_true")
    parser.add_argument('--add_softmax', action="store_true")
    parser.add_argument('--attn_proj_vary', action="store_true")

    parser.add_argument('--learn_prior', action="store_true")
    parser.add_argument("--eval_batch_size", default=100, type=int, help="Batch size per GPU/CPU for evaluation.")

    args = parser.parse_args()
    args.log_file = os.path.join(args.output_dir, 'log.txt')
    print(args)

    args.model_type = 'cvae'
    args.learn_prior = True

    # GPU
    if not torch.cuda.is_available(): args.no_gpu = True
    gpu = not args.no_gpu
    if gpu: torch.cuda.set_device(args.gpu)
    device = torch.device(args.gpu if gpu else "cpu")

    # randomness
    np.random.seed(args.seed)
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    if gpu: torch.cuda.manual_seed(args.seed)

    if args.batch_size == -1:
        args.batch_size = 1

    # logging
    save_folder = os.path.join(args.output_dir, args.experiment)
    os.makedirs(save_folder, exist_ok=True)

    # importlib.reload(logging)
    logger = get_logger(args.log_file)
    logger.info('\n----------------------------------------------------------------------')

    print('Loading models...')
    cache_dir = os.path.join(args.output_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # * Load VAE model
    VAE = VAEModel(config, add_input=args.add_input, add_attn=args.add_attn, add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary, learn_prior=args.learn_prior)

    args.load = args.model_path
    print('Loading model weights...')
    state = torch.load(args.load)
    VAE.load_state_dict(state)
    print('=' * 25, 'Load trained model successfully.', '=' * 25, flush=True)
    gc.collect()
    print('VAE_params:', num_params(VAE))  # 286694400

    print('Setup data...', flush=True)
    seq_len = VAE.config.n_ctx
    VAE.config.n_ctx = 100  # xiu setting.
    print('VAE config n_ctx: sequence length', seq_len, flush=True)

    # * Get test loader.
    datasets = get_datasets(args.data_dir, args.tasks, tokenizer, num_workers=1, ctx_max_len=100)
    for task in datasets:
        # data_loaders[task] = {}
        # train_sampler = torch.utils.data.RandomSampler(datasets[task]['train'])
        # train_loader = DataLoader(datasets[task]['train'], batch_size=args.train_batch_size, sampler=train_sampler,
        #                           num_workers=1, pin_memory=True, collate_fn=PadBatchSeq(tokenizer.eos_token_id))
        # val_loader = DataLoader(datasets[task]['val'], batch_size=args.eval_batch_size, sampler=None,
        #                         num_workers=1, pin_memory=True, collate_fn=PadBatchSeq(tokenizer.eos_token_id))
        test_loader = DataLoader(datasets[task]['test'], batch_size=args.eval_batch_size, sampler=None,
                                 num_workers=1, pin_memory=True, collate_fn=PadBatchSeq(tokenizer.eos_token_id))
        # note: test_loader is overwritten on each iteration, so only the last task's test split is evaluated below

    VAE.eval()  # be careful about VAE.eval() vs VAE.train()
    VAE.to(device)

    logger.info('\n----------------------------------------------------------------------')
    logger.info("Testing loop. batches: %d" % len(test_loader))

    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    n_samples = 0
    bleu4_sum = 0.0
    rouge_scores_values_sum = [0.0] * 9

    model_type = args.model_type

    check_file = open(os.path.join(save_folder, 'check.txt'), 'w', encoding='utf8')
    with tqdm(total=len(test_loader)) as pbar:
        ITER = enumerate(test_loader)
        for i, data in ITER:
            x_mask = data['prompt_mask']
            x_tokens = data['prompt_id']
            # x_lens = data['prompt_lens']
            x_lens = torch.tensor([len(t) - 1 for t in x_tokens], dtype=torch.long).to(device)

            length = args.length
            # Changed to train.py
            if length == -1:
                # length = VAE.config.n_ctx - 1
                length = VAE.config.n_ctx - x_tokens.size(1) - 1
            # elif length > VAE.config.n_ctx - 1:
            elif length > VAE.config.n_ctx - x_tokens.size(1) - 1:
                raise ValueError("Can't get samples longer than window size: %s" % VAE.config.n_ctx)

            target_tokens = data['input_id'][..., 1:].contiguous()
            eff_samples = []
            n, l = target_tokens.size()
            # storys = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
            # storys = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in storys]
            # storys_str = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")] if "<|endoftext|>" in s else s for s in storys]

            for _ in range(1):
                all_out = []
                sample_time = 5  # Sample 5 times.
                for zz in range(sample_time):
                    out = sample_sequence(
                        model=VAE,
                        tokenizer=tokenizer,
                        length=length,
                        batch_size=x_tokens.size()[0],
                        x_mask=x_mask,
                        x_tokens=x_tokens,
                        x_lens=x_lens,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                    )
                    out = out.tolist()
                    all_out.append(out)

                # * Check latent code z ability.
                for ss in range(len(all_out[0])):
                    check_file.write('\n')
                    check_file.write('=' * 20 + 'SAMPLE: %d' % n_samples + '=' * 20)
                    check_file.write('\n')
                    iii = 1
                    for oout in all_out:
                        text = decode_text(oout[ss], endoftext, tokenizer)
                        check_file.write(str(iii) + '==:')
                        check_file.write(text)
                        check_file.write('\n')
                        iii += 1
                    check_file.write('-' * 100)
                    check_file.flush()

                    # text = decode_text(out[ss], endoftext, tokenizer)
                    # check_file.write('-'*100)
                    # check_file.write('\n')
                    # check_file.write('SAMPLE: %d'%n_samples)
                    # check_file.write('='*2)
                    # check_file.write(':')
                    # check_file.write(text)
                    # check_file.write('\n')
                    # check_file.write('-'*100)
                    # check_file.flush()

                    # tx = decode_text(oout[ss], endoftext, tokenizer)
                    # # print(tx,flush=True)
                    # check_file.write(str(iii))
                    # check_file.write('====')
                    # check_file.write(tx)
                    # check_file.write('\n')
                    # iii+=1
                    n_samples += 1
                    eff_samples.append((text, 0))  # `text` is the last of the sample_time generations for this example

                    if n_samples >= 100:  # Only evaluate part of the test data.
                        break

            # bleu, rouge score calculation
            # score for one long text, higher than 0.075 usually means repetition
            # rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
            # if rep_score > 0.075:
            #     # print(rep_score)
            #     continue

            # try:
            #     # check bleu
            #     bleu4 = sentence_bleu([storys_str[i].split()], text, smoothing_function=SmoothingFunction().method7)

            #     # check rouge
            #     rouge = Rouge()
            #     rouge_scores = rouge.get_scores(text, storys_str[i])
            #     rouge_scores_values = [v for k in rouge_scores[0].keys() for v in rouge_scores[0][k].values()]

            #     bleu4_sum += bleu4
            #     rouge_scores_values_sum = [v1 + v2 for v1, v2 in zip(rouge_scores_values_sum, rouge_scores_values)]
            #     n_samples += 1
            # except:
            #     bleu4 = 0.0
            #     rouge_scores = [{'rouge-1': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            #                      'rouge-2': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            #                      'rouge-l': {'f': 0.0, 'p': 0.0, 'r': 0.0}}]
            # Store generated sentences.
            # eff_samples.append((text,0))
            # eff_samples.append((text, bleu4, rouge_scores))

            # write samples to file
            # for i in range(len(eff_samples)):
            #     samples_file.write("=" * 50 + " SAMPLE " + str(i) + " " + "=" * 50)
            #     samples_file.write('\n' * 2)

            #     samples_file.write("=" * 40 + " Outlines " + "=" * 40)
            #     samples_file.write('\n' * 2)
            #     samples_file.write(tokenizer.decode(x_tokens[i, :][x_mask[i, :] == 1].tolist()))
            #     samples_file.write('\n' * 2)
            #     samples_file.write("=" * 40 + " Story " + "=" * 40)
            #     samples_file.write('\n' * 2)
            #     samples_file.write(storys_str[i])
            #     samples_file.write('\n' * 2)

            #     samples_file.write("=" * 40 + " Generated " + "=" * 40)
            #     samples_file.write('\n' * 2)
            #     samples_file.write(eff_samples[i][0])
            #     samples_file.write('\n' * 1)
            #     samples_file.flush()

            logger.info('batch %d finished.' % n_samples)
            pbar.update(1)

    check_file.close()
    print('Test complete with %d samples.' % n_samples)
    logger.info("Test complete with %d samples." % n_samples)

    # bleu4 = round(bleu4_sum / n_samples, 3)
    # rouge_scores_values = [round(r / n_samples, 3) for r in rouge_scores_values_sum]
    # print(' bleu-4:', bleu4)
    # print(' rouge :', rouge_scores_values)
    # logger.info(' bleu-4: %f', bleu4)
    # logger.info(' rouge : %s', str(rouge_scores_values))


if __name__ == '__main__':
    run_model()
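
# Example invocation (the script name and paths below are placeholders; only the flags
# defined in run_model's argument parser are real):
#   python generate.py --experiment demo_run --model_path ./checkpoints/vae.pt \
#       --tasks banking --add_attn --output_dir out \
#       --temperature 0.95 --top_p 0.95 --top_k 100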