# NOTE(review): removed non-code scraper residue ("786 lines", "42 KiB", "Python")
# that preceded the imports and would break the module at import time.
from mycvae.utils import *
|
|
from mycvae.model import *
|
|
import threading
|
|
import torch
|
|
import os, shutil
|
|
from torch.utils.data import DataLoader
|
|
from dataset import PadBatchSeq, TASK2INFO, MixedCLSDataset, MixedSlotTaggingDataset, PromptCLSDataset, PromptSlotTaggingDataset
|
|
from tqdm import tqdm
|
|
import json
|
|
import torch.distributed as dist
|
|
import os, time, gc, json, pickle, argparse, math
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.data as data
|
|
from torch.nn import DataParallel
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
import numpy as np
|
|
import transformers
|
|
from transformers import get_linear_schedule_with_warmup, Conv1D, AdamW
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm import tqdm
|
|
import importlib
|
|
import copy
|
|
from apex.optimizers import FusedAdam
|
|
from apex import amp
|
|
from apex.fp16_utils import FP16_Optimizer
|
|
from collections import Counter
|
|
from nltk.translate.bleu_score import sentence_bleu
|
|
from nltk.translate.bleu_score import SmoothingFunction
|
|
from rouge import Rouge
|
|
from sklearn.manifold import TSNE
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt
|
|
from dataset import get_datasets, get_dataclass_dict
|
|
# from generate import *
|
|
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model
|
|
|
|
def compute_vae_loss(device, model, loss_fn, beta, vae_total, distill=False, prev_model=None):
    """Compute the CVAE training loss: masked token CE + beta-weighted KL.

    Args:
        device: torch device the batch should be moved to.
        model: CVAE whose forward returns (logits, ..., kl_loss).
        loss_fn: per-token loss with reduction='none'; when `distill` is set it
            must accept teacher logits as a third positional argument (KD loss).
        beta: KL annealing weight for this step.
        vae_total: tuple (p_mask, p_tokens, px_mask, px_tokens, input_tokens,
            attention_mask, target_tokens, input_label_mask).
        distill: if True, also run `prev_model` as a frozen teacher.
        prev_model: teacher model; required when `distill` is True.

    Returns:
        (loss, ce_loss, kl_loss) — total loss is ce + beta * kl * 0.01.
    """
    p_mask, p_tokens, px_mask, px_tokens, input_tokens, attention_mask, target_tokens, input_label_mask = vae_total
    input_tokens = input_tokens.to(device)
    target_tokens = target_tokens.to(device)
    attention_mask = attention_mask.to(device)
    p_mask = p_mask.to(device)
    p_tokens = p_tokens.to(device)
    px_mask = px_mask.to(device)
    px_tokens = px_tokens.to(device)

    # * Compute CVAE outputs.
    outputs = model(input_ids=input_tokens, attention_mask=attention_mask, p_mask=p_mask, p_tokens=p_tokens, px_mask=px_mask, px_tokens=px_tokens)
    logits = outputs[0]
    kl_loss = outputs[-1]
    num_logits = logits.size(-1)

    # Perform masking (kept for behavior parity; the bool mask itself is not
    # reused below — loss masking goes through input_label_mask instead).
    if attention_mask is not None:
        attention_mask = attention_mask.type(torch.bool)
        attention_mask = attention_mask.to(device)
    # NOT compute losses on pre-defined prompt tokens: positions with
    # input_label_mask > 0 are the only ones contributing to the CE loss.
    loss_mask = (input_label_mask > 0).view(-1).to(device)

    if distill:
        # * Teacher outputs for knowledge distillation from the previous-task model.
        assert prev_model is not None
        prev_model.eval()
        with torch.no_grad():
            teacher_outputs = prev_model(input_ids=input_tokens, attention_mask=attention_mask, p_mask=p_mask, p_tokens=p_tokens, px_mask=px_mask, px_tokens=px_tokens)
            teacher_logits = teacher_outputs[0].contiguous()
        ce_loss = loss_fn(logits.view(-1, num_logits), target_tokens.view(-1), teacher_logits.view(-1, teacher_logits.size(-1)))
    else:
        ce_loss = loss_fn(logits.view(-1, num_logits), target_tokens.view(-1))

    # BUG FIX: use the loss tensor's own device instead of hard-coded .cuda(),
    # so the function also works on CPU / an explicitly selected device.
    zero = torch.tensor(0.0, device=ce_loss.device)
    ce_loss_masked = ce_loss.where(loss_mask, zero)
    ce_loss = ce_loss_masked.sum() / loss_mask.sum()

    kl_loss = kl_loss.mean()
    # 0.01 further down-weights the (already annealed) KL term.
    loss = ce_loss + beta * kl_loss * 0.01

    return loss, ce_loss, kl_loss
|
|
|
|
def compute_lm_loss(device, model, loss_fn, lm_total, distill=False, prev_model=None):
    """Compute the language-modeling loss for the CVAE decoder.

    Args:
        device: torch device the batch should be moved to.
        model: object exposing `.decoder` (a causal LM returning `.logits`).
        loss_fn: per-token loss with reduction='none'; when `distill` is set it
            must accept teacher logits as a third positional argument (KD loss).
        lm_total: tuple (input_tokens, attention_mask, input_label_mask,
            target_tokens), all (batch, seq).
        distill: if True, also run `prev_model.decoder` as a frozen teacher.
        prev_model: teacher model; required when `distill` is True.

    Returns:
        Scalar loss = qa_loss + 0.5 * lm_loss, where qa_loss averages CE over
        answer positions (input_label_mask > 0) and lm_loss over all
        non-padding positions (attention_mask > 0).
    """
    lm_input_tokens, lm_attention_mask, lm_input_label_mask, lm_target_tokens = lm_total
    lm_input_tokens, lm_attention_mask, lm_input_label_mask, lm_target_tokens = lm_input_tokens.to(device), lm_attention_mask.to(device), lm_input_label_mask.to(device), lm_target_tokens.to(device)

    # * Compute LM outputs and CE loss for the VAE decoder w/o 'add_attn'.
    model = model.to(device)
    # NOTE(review): this force-enables train mode (dropout) even when called
    # from evaluation code — confirm this is intended.
    model.decoder.train()
    # NOTE(review): the attention mask is not forwarded to the decoder here.
    outputs = model.decoder(input_ids=lm_input_tokens)
    logits = outputs.logits.contiguous()

    # Masks selecting answer positions and non-padding positions respectively.
    loss_mask = (lm_input_label_mask > 0).view(-1)
    loss_padding_mask = (lm_attention_mask > 0).view(-1)
    logits = logits.view(-1, logits.size(-1))

    if distill:
        # * Teacher outputs for knowledge distillation.
        assert prev_model is not None
        prev_model.eval()
        with torch.no_grad():
            tc_outputs = prev_model.decoder(input_ids=lm_input_tokens)
            teacher_logits = tc_outputs.logits.contiguous()
        ce_loss = loss_fn(logits, lm_target_tokens.view(-1), teacher_logits.view(-1, teacher_logits.size(-1)))
    else:
        ce_loss = loss_fn(logits, lm_target_tokens.view(-1))

    # BUG FIX: use the loss tensor's own device instead of hard-coded .cuda(),
    # so the function also works on CPU / an explicitly selected device.
    zero = torch.tensor(0.0, device=ce_loss.device)
    lm_loss_masked = ce_loss.where(loss_mask, zero)
    lm_pad_loss_masked = ce_loss.where(loss_padding_mask, zero)
    qa_loss = lm_loss_masked.sum() / loss_mask.sum()
    lm_loss = lm_pad_loss_masked.sum() / loss_padding_mask.sum()

    # Answer loss dominates; the full-sequence LM loss acts as a regularizer.
    loss = qa_loss + 0.5 * lm_loss

    return loss
|
|
|
|
def train_step(device, model, optimizer, loss_fn, beta, vae_total, lm_total, distill=False, only_decoder=False, only_vae=False, prev_model=None):
    """Run one optimization step over the joint CVAE + LM objective.

    Args:
        device: torch device for the batch.
        model: CVAE model (must expose `.decoder` for the LM branch).
        optimizer: optimizer stepping `model`'s parameters.
        loss_fn: per-token loss passed through to the loss helpers.
        beta: KL annealing weight for this step.
        vae_total / lm_total: input tuples for compute_vae_loss / compute_lm_loss.
        distill: forward a frozen `prev_model` as teacher in both branches.
        only_decoder: skip the CVAE branch, train the decoder LM only.
        only_vae: optimize the CVAE loss only (LM loss still computed/logged).
        prev_model: teacher model for distillation.

    Returns:
        Single-element list [(vae_loss, vae_ce_loss, vae_kl_loss, lm_loss)]
        as Python floats.
    """
    output = []
    optimizer.zero_grad()

    if only_decoder:
        # Decoder-only training: CVAE branch contributes zero loss.
        vae_loss, vae_ce_loss, vae_kl_loss = torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
    else:
        vae_loss, vae_ce_loss, vae_kl_loss = compute_vae_loss(device, model, loss_fn, beta, vae_total=vae_total, distill=distill, prev_model=prev_model)

    lm_loss = compute_lm_loss(device, model, loss_fn, lm_total, distill=distill, prev_model=prev_model)

    # Select which losses are optimized this step.
    if not only_decoder and not only_vae:
        total_loss = vae_loss + 0.5 * lm_loss
    elif only_vae:
        total_loss = vae_loss
    else:
        total_loss = lm_loss

    total_loss.backward()
    # BUG FIX: clip AFTER backward() so the freshly computed gradients are the
    # ones being clipped. The original called clip_grad_norm_ before backward,
    # which clips stale (or zero) gradients and leaves the current step's
    # gradients unbounded.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # max_grad_norm

    optimizer.step()
    output.append((vae_loss.item(), vae_ce_loss.mean().item(), vae_kl_loss.item(), lm_loss.item()))

    return output
|
|
|
|
def get_model_input(batch, input_type='vae', step=100, tokz=None):
    """Slice a collated batch into model-ready (shifted) tensors.

    For input_type='vae', returns the 8-tuple
    (prompt_mask, prompt_ids, posterior_mask, posterior_ids, dec_inputs,
    dec_attn, targets, label_mask) consumed by compute_vae_loss; for any
    other value, returns the 4-tuple (lm_inputs, lm_attn, lm_label_mask,
    lm_targets) consumed by compute_lm_loss. Inputs drop the last token and
    targets drop the first token (standard next-token shift). During the
    first few global steps (step < 3) the first row of the batch is printed
    for manual inspection via `tokz`.
    """
    vae_mode = input_type == 'vae'
    if vae_mode:
        posterior_ids = batch['posterior_id'][..., :-1]
        posterior_mask = batch['posterior_mask'][..., :-1].contiguous()
        prompt_ids = batch['prompt_id']
        prompt_mask = batch['prompt_mask'].contiguous()
        dec_inputs = batch['input_id'][..., :-1].contiguous()
        dec_attn = batch['input_mask'][..., :-1].contiguous()
        label_mask = batch['input_label_mask'][..., 1:].contiguous()
        targets = batch['input_id'][..., 1:].contiguous()
        result = (prompt_mask, prompt_ids, posterior_mask, posterior_ids, dec_inputs, dec_attn, targets, label_mask)
        if step < 3:
            # Debug dump of the first example.
            print('*'*10,'VAE','*'*10, flush=True)
            print('prefix', tokz.decode(prompt_ids[0]), flush=True)
            print("input_id", posterior_ids[0], flush=True)
            print('input', tokz.decode(posterior_ids[0]), flush=True)
            print('input_mask', posterior_mask[0], flush=True)
            print('target_id', targets[0], flush=True)
            print('target', tokz.decode(targets[0]), flush=True)
            # print('\n',flush=True)
    else:
        lm_inputs = batch['all_id'][..., :-1]
        lm_attn = batch['all_mask'][..., :-1].contiguous()
        lm_targets = batch['all_id'][..., 1:].contiguous()
        lm_label_mask = batch['all_label_mask'][..., 1:].contiguous()
        result = (lm_inputs, lm_attn, lm_label_mask, lm_targets)
        if step < 3:
            # Debug dump of the first example.
            print('*'*10, 'LM', '*'*10, flush=True)
            print('all id (utterance and answer)', lm_inputs[0], flush=True)
            print('all', tokz.decode(lm_inputs[0]), flush=True)
            print('all mask', lm_attn[0], flush=True)
            print('all target tokens', lm_targets[0], flush=True)
            print('all tgt', tokz.decode(lm_targets[0]), flush=True)
            print('all label mask (qa)', lm_label_mask[0], flush=True)
            print('\n', flush=True)

    return result
|
|
|
|
|
|
class Trainer:
|
|
def __init__(self, args, tokz, datasets, logger=None, cache_dir=None, memory=None):
    """Wire up the continual-learning trainer.

    Stores args/tokenizer/datasets, creates TensorBoard writers on the master
    rank, builds per-task train/val/test DataLoaders, seeds all RNGs, and
    prepares the loss functions and output directories.

    Args:
        args: parsed CLI namespace (local_rank, device, tasks, batch sizes, ...).
        tokz: GPT-2 tokenizer shared by all datasets.
        datasets: dict task -> {'train'|'val'|'test': Dataset}.
        logger: logging.Logger used throughout training.
        cache_dir: optional cache directory passed through from the caller.
        memory: latent-memory object, kept only when args.use_memory is set.
    """
    # Other
    self.args = args
    self.datasets = datasets
    self.logger = logger
    self.rank = self.args.local_rank
    self.device = self.args.device
    self.global_step = 0
    self.task_step = 0
    # local_rank == -1 means single-process (non-distributed) training.
    distributed = False if self.rank == -1 else True
    # Map task name -> position in the task sequence.
    self.task_dict = {k:v for v,k in enumerate(self.args.tasks)}

    # Tensorboard log
    if self.rank in [0, -1]: # Only write logs on master node
        self.train_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'train'), flush_secs=10)
        self.valid_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'valid'))

    # Default GPT-2 config; the actual weights are loaded later in _train.
    self.config = GPT2Config()
    # print('GPT2 Config:',self.config,flush=True)
    # self.tokz = tokenizer
    self.tokz = tokz
    print('len tokz gpt2',len(self.tokz),flush=True)

    self.cache_dir = cache_dir

    # Set model we use.
    # self.model = self.set_model(self.config, add_input=self.args.add_input, add_attn=self.args.add_attn)
    # * If we use memory
    if self.args.use_memory:
        self.memory = memory
    else:
        self.memory = None

    # * Load datasets.
    self.data_loaders = {}
    for task in self.datasets:
        self.data_loaders[task] = {}
        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(datasets[task]['train'])
            valid_sampler = torch.utils.data.distributed.DistributedSampler(datasets[task]['val'])
        else:
            train_sampler = torch.utils.data.RandomSampler(datasets[task]['train'])
            valid_sampler = None

        self.data_loaders[task]['train'] = DataLoader(datasets[task]['train'], batch_size=self.args.train_batch_size,
            sampler=train_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.eos_token_id))
        self.data_loaders[task]['val'] = DataLoader(datasets[task]['val'], batch_size=self.args.eval_batch_size,
            sampler=valid_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.eos_token_id))
        # NOTE(review): the test loader reuses `valid_sampler`; when distributed
        # that sampler was built over the *val* split — confirm this is intended.
        self.data_loaders[task]['test'] = DataLoader(datasets[task]['test'], batch_size=self.args.eval_batch_size,
            sampler=valid_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.eos_token_id))

    # randomness
    np.random.seed(args.seed)
    # NOTE(review): `prng` is created but never stored/used afterwards.
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    gpu = not self.args.no_gpu
    if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)

    # Initial KL annealing weight.
    self.beta = args.beta_0
    # reduction='none' so losses can be masked per-token downstream.
    self.loss_fn = nn.CrossEntropyLoss(reduction='none').to(self.device, non_blocking=True)
    self.kd_loss = KDLoss(KD_term=self.args.KD_term, T=self.args.KD_temperature)

    # Save paths.
    # self.save_folder = os.path.join(args.output_dir, args.experiment)
    self.save_folder = args.output_dir
    os.makedirs(self.save_folder, exist_ok=True)
    self.output_encoder_dir = self.save_folder
    self.output_decoder_dir = self.save_folder
|
|
|
|
def _train(self, model, curr_task, prev_tasks, all_tasks_res, prev_model=None):
    """Train the CVAE on `curr_task`, optionally replaying/distilling old tasks.

    Flow: (1) initialize or reuse the CVAE (and a frozen teacher copy when
    args.add_kd); (2) when args.gen_replay, generate pseudo data for every
    previous task and build a mixed replay dataset; (3) run the epoch/step
    loop with cyclic KL annealing, periodic evaluation and checkpointing;
    (4) optionally push prior/posterior latents into the memory; (5) save
    the final model.

    Args:
        model: current CVAE, or None on the very first task (then initialized
            from the local './gpt2' checkpoint).
        curr_task: name of the task to train on.
        prev_tasks: list of already-trained task names.
        all_tasks_res: running list of per-eval metric dicts (appended in place).
        prev_model: frozen teacher for knowledge distillation (args.add_kd).

    Returns:
        (VAE, prev_VAE, all_tasks_res) when args.add_kd, else (VAE, all_tasks_res).
    """
    # * CVAE model, optimizer setting.
    if model is None:
        assert len(prev_tasks) == 0
        model_path='./gpt2'
        VAE = CVAEModel(self.config, self.args)
        VAE.initialize(model_path)
        self.logger.info("Successfully initialize the model with pretrained GPT2 model!")

    else:
        VAE = model
    VAE = VAE.to(self.device, non_blocking=True)
    VAE.train()

    if self.args.add_kd:
        if prev_model is None:
            assert len(prev_tasks) == 0
            prev_VAE = PrevCVAEModel(self.config, self.args)
            # NOTE(review): `model_path` is only bound in the `model is None`
            # branch above; this relies on `model is None` whenever
            # `prev_model is None` (first task) — otherwise NameError.
            prev_VAE.initialize(model_path)
            # Teacher starts as an exact copy of the student.
            prev_VAE.load_state_dict(VAE.state_dict())
        else:
            prev_VAE = prev_model
        # NOTE(review): hard-coded 'cuda' here rather than self.device.
        prev_VAE.to('cuda')
        prev_VAE.eval()

    # print('NUMBER of model parameters.',len(list(model.parameters())),flush=True)
    print("NUMBER of PARAMs",sum(p.numel() for p in VAE.parameters()),flush=True)

    # ! Prepare training datasets.
    train_dataset = self.datasets[curr_task]['train']

    if self.args.gen_replay and len(prev_tasks)>0:
        # * Generate pseudo data from previous model
        pseudo_data_count = int(len(self.datasets[curr_task]['train']) * self.args.pseudo_data_ratio) // len(prev_tasks)
        if self.rank in [0, -1]:
            self.logger.info(f' pseudo data count for each task {pseudo_data_count}')

        inferred_CLS_data = {}
        inferred_ST_data = {}
        for task in prev_tasks:
            # ! Generate pseudo data for old observed tasks.
            pseudo_output_file = os.path.join(self.args.output_dir, curr_task+'_pseudo_'+task+'.json')
            self.logger.info('Generated pseudo will be writen into '+str(pseudo_output_file))
            data = gen_pseudo_data(VAE, task, self.datasets[task]['train'], max_output_len=self.args.ctx_max_len,
                                   batch_size=self.args.eval_batch_size, target_count=pseudo_data_count,
                                   output_file=pseudo_output_file,
                                   top_k=self.args.top_k, top_p=self.args.top_p, temperature=self.args.temperature,
                                   only_decoder=self.args.only_decoder, memory=self.memory, args=self.args)

            # Route the pseudo samples by the generating task's type.
            if TASK2INFO[task]['task_type'] == 'CLS':
                inferred_CLS_data[task] = data
            elif TASK2INFO[task]['task_type'] == 'SlotTagging':
                inferred_ST_data[task] = data
            else:
                pass
            if self.rank in [0, -1]: # Log out sample data
                self.logger.info(f'Inferring pseudo data from {task}')
                for i in range(0, min(6, pseudo_data_count)):
                    self.logger.info(f' {data[i]}')

        # Build the replay dataset in the current task's format.
        if TASK2INFO[curr_task]['task_type'] == 'SlotTagging':
            prev_train_dataset = MixedSlotTaggingDataset(inferred_ST_data, tokz=self.tokz, ctx_max_len=self.args.ctx_max_len)
        elif TASK2INFO[curr_task]['task_type'] == 'CLS':
            prev_train_dataset = MixedCLSDataset(inferred_CLS_data, tokz=self.tokz, ctx_max_len=self.args.ctx_max_len)

    self.logger.info("Begin training!"+ "=" * 40)
    self.logger.info(f'Currently training {curr_task}'+'-'*10)

    # Fresh per-task directory for evaluation dumps.
    evalout_dir = os.path.join(self.args.output_dir, curr_task)
    if os.path.exists(evalout_dir) and os.path.isdir(evalout_dir):
        shutil.rmtree(evalout_dir)
    os.makedirs(evalout_dir, exist_ok=True)

    # Optimizer weight decay
    param_optimizer = list(VAE.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': self.args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
    optimizer.zero_grad()
    if self.rank == -1:
        train_sampler = torch.utils.data.RandomSampler(train_dataset) # not distributed
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) # distributed
    # * Get train loader.
    train_loader = DataLoader(train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler,
                              num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.eos_token_id))
    if len(prev_tasks)>0:
        print('We have trained %d tasks'%(len(prev_tasks)),flush=True)
        prev_train_sampler = torch.utils.data.RandomSampler(prev_train_dataset)
        prev_train_loader = DataLoader(prev_train_dataset, batch_size=self.args.train_batch_size, sampler=prev_train_sampler,
                                       num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.eos_token_id))

    # Cyclic KL-annealing schedule: one beta value per training iteration.
    n_iter = self.args.num_train_epochs[curr_task] * len(train_loader)
    beta_list = frange_cycle_linear(n_iter, start=0.0, n_cycle=self.args.num_cycle, ratio=0.9)
    self.logger.info("Beta list we will use to train"+str(beta_list))

    t_total = len(train_loader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs[curr_task]
    save_steps = int(t_total // self.args.num_train_epochs[curr_task] * self.args.save_epochs)
    if self.args.warmup_steps > 0:
        num_warmup_steps = self.args.warmup_steps
    else:
        num_warmup_steps = int(t_total * self.args.warmup_proportion)
    # NOTE(review): this computed `eval_steps` is only logged; the eval trigger
    # inside the loop uses self.args.eval_steps directly — confirm.
    if self.args.eval_times_per_task > 0:
        eval_steps = int(t_total / self.args.eval_times_per_task)
    else:
        eval_steps = self.args.eval_steps
    if not self.args.nouse_scheduler: # if True, not change lr.
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, t_total)

    # TODO: Training ==============================================================================
    for epoch in range(1, self.args.num_train_epochs[curr_task] + 1):
        if self.rank in [0, -1]:
            self.logger.info(
                f'num_warmup_steps: {num_warmup_steps}, t_total: {t_total}, save_steps: {save_steps}, eval_steps: {eval_steps}')

        self.task_step, iter_count, avg_loss = 0, 0, 0
        self.logger.info("Total epoches: %d" % self.args.num_train_epochs[curr_task])
        self.logger.info('*'*50+"Epoch: %d" % epoch+'*'*50)
        st = time.time()

        ITER = tqdm(enumerate(train_loader), dynamic_ncols=True, total=len(train_loader))
        if len(prev_tasks) > 0:
            prev_ITER = enumerate(prev_train_loader)
            # prev_ITER = tqdm(enumerate(prev_train_loader), dynamic_ncols=True, total=len(prev_train_loader))

        # * For loop
        for i, data in ITER:
            optimizer.zero_grad()
            beta = beta_list[i + (epoch-1) * len(ITER)]
            iter_count += 1

            # * Get inputs of the CVAE model.
            #! x [prompt] for prior; y [prompt + input] for posterior; input_token use only prompt; target is the utter_id. 1211 modified.
            if not self.args.general_prompt:
                vae_total = get_model_input(data, input_type='vae', step=self.global_step, tokz=self.tokz)
                lm_total = get_model_input(data, input_type='lm', step=self.global_step, tokz=self.tokz)

            else:
                # Generic-prompt variant: slice the 'gene_*' fields directly.
                p_mask = data['gene_prompt_mask']
                p_tokens = data['gene_prompt_id']
                px_tokens = data['gene_posterior_id']
                px_mask = data['gene_posterior_mask']
                input_tokens = data['gene_input_id'][..., :-1]
                attention_mask = data['gene_input_mask'][...,:-1].contiguous()
                input_label_mask = data['gene_input_label_mask'][...,1:].contiguous()
                target_tokens = data['gene_input_id'][...,1:].contiguous()
                # * Get inputs of the LM model (CVAE decoder).
                lm_input_tokens = data['gene_all_id'][...,:-1]
                lm_attention_mask = data['gene_all_mask'][...,:-1].contiguous()
                lm_input_label_mask = data['gene_all_label_mask'][...,1:].contiguous()
                lm_target_tokens = data['gene_all_id'][...,1:].contiguous()
                lm_total = (lm_input_tokens, lm_attention_mask, lm_input_label_mask, lm_target_tokens)
                vae_total = (p_mask, p_tokens, px_mask, px_tokens, input_tokens, attention_mask, target_tokens, input_label_mask)

            # * Get CVAE + LM outputs.
            all_output = train_step(self.device, VAE, optimizer, self.loss_fn, beta,
                                    vae_total=vae_total, lm_total=lm_total, only_decoder=self.args.only_decoder,
                                    only_vae=self.args.only_vae, distill=False, prev_model=None)

            vae_loss, vae_ce_loss, vae_kl_loss, lm_loss = all_output[-1]

            # * Distill the pseudo samples of previous tasks.
            if self.args.add_kd:
                if len(prev_tasks) > 0:
                    # print('prev_iter',len(prev_ITER),flush=True)
                    # NOTE(review): next(iter(loader)) rebuilds the iterator every
                    # step, so each call draws the first batch of a fresh shuffle.
                    if not self.args.general_prompt:
                        prev_data = next(iter(prev_train_loader))
                        prev_vae_total = get_model_input(prev_data, input_type='vae', step=self.global_step, tokz=self.tokz)
                        prev_lm_total = get_model_input(prev_data, input_type='lm', step=self.global_step, tokz=self.tokz)
                    else:
                        prev_data = next(iter(prev_train_loader))
                        prev_lm_input_tokens = prev_data['gene_all_id'][...,:-1]
                        prev_lm_attn_mask = prev_data['gene_all_mask'][...,:-1].contiguous()
                        prev_lm_input_label_mask = prev_data['gene_all_label_mask'][...,1:].contiguous()
                        prev_lm_target_tokens = prev_data['gene_all_id'][...,1:].contiguous()
                        prev_p_mask = prev_data['gene_prompt_mask']
                        prev_p_tokens = prev_data['gene_prompt_id']
                        prev_px_tokens = prev_data['gene_posterior_id']
                        prev_px_mask = prev_data['gene_posterior_mask']
                        prev_input_tokens = prev_data['gene_input_id'][..., :-1]
                        prev_attention_mask = prev_data['gene_input_mask'][...,:-1].contiguous()
                        prev_input_label_mask = prev_data['gene_input_label_mask'][..., 1:].contiguous()
                        prev_target_tokens = prev_data['gene_input_id'][..., 1:].contiguous()

                        prev_vae_total = (prev_p_mask, prev_p_tokens, prev_px_mask, prev_px_tokens, prev_input_tokens, prev_attention_mask, prev_target_tokens, prev_input_label_mask)
                        prev_lm_total = (prev_lm_input_tokens, prev_lm_attn_mask, prev_lm_input_label_mask, prev_lm_target_tokens)

                    # KD pass: self.kd_loss consumes (student logits, targets, teacher logits).
                    _ = train_step(self.device, VAE, optimizer, self.kd_loss, beta, vae_total=prev_vae_total, lm_total=prev_lm_total,
                                   only_vae=False, distill=True, prev_model=prev_VAE)

            optimizer.step()
            # Update the learning rate.
            lr = scheduler.get_last_lr()[0]

            # Write in the logger.
            self.logger.info(f'EPOCH: {epoch}, task step: {self.task_step}, global step: {self.global_step} ' + \
                'CVAE total loss: %.4f, CVAE ce loss: %.4f, CVAE kl loss: %.4f, LM loss: %.4f' % (vae_loss, vae_ce_loss, vae_kl_loss, lm_loss))

            # Log CVAE info to Tensorboard
            self.train_writer.add_scalar('vae_loss_total', vae_loss, self.global_step)
            self.train_writer.add_scalar('vae_ce_loss', vae_ce_loss, self.global_step)
            self.train_writer.add_scalar('vae_kl_loss', vae_kl_loss, self.global_step)
            self.train_writer.add_scalar('lm_loss', lm_loss, self.global_step)
            self.train_writer.add_scalar('time_per_batch', time.time() - st, self.global_step)
            # self.train_writer.add_scalar('ppl', math.exp(min(vae_ce_loss, 10)), self.global_step)
            self.train_writer.add_scalar('lr', lr, self.global_step)
            self.train_writer.add_scalar('beta', beta, self.global_step)

            st = time.time()
            if not self.args.nouse_scheduler:
                scheduler.step() # Can change the lr.

            self.global_step += 1
            self.task_step += 1

            # TODO: Evaluate model on Valid datasets. ================================================================================================
            if self.global_step % self.args.eval_steps == 0 and self.args.eval_steps > 0:
                all_res_per_eval = {}
                avg_metric_loss, avg_metric = [], []
                # Evaluate on the current task plus every previously seen task.
                for val_task in self.args.tasks[:len(prev_tasks)+1]:
                    eval_loss = self._eval(VAE, val_task, self.kd_loss)
                    all_prior_info = get_all_priors(VAE, self.tokz, self.args)
                    # * Compute Accuracy or F1.
                    if 'CLS' == TASK2INFO[val_task]['task_type']: # classification task
                        eval_metric = self._eval_acc(VAE, val_task, all_prior_info=all_prior_info, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None)
                    elif 'SlotTagging' == TASK2INFO[val_task]['task_type']: # slot-filling task
                        eval_metric = self._eval_f1(VAE, val_task, all_prior_info=all_prior_info, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None)
                    else:
                        raise ValueError('Eval metric not defined for task %s' % val_task)

                    avg_metric_loss.append(eval_loss['lm_loss'])
                    if 'CLS' == TASK2INFO[val_task]['task_type']:
                        avg_metric.append(eval_metric['acc'])
                    elif 'SlotTagging' == TASK2INFO[val_task]['task_type']:
                        avg_metric.append(eval_metric['f1'])
                    else:
                        pass

                    for key, value in eval_loss.items():
                        self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)
                    for key, value in eval_metric.items():
                        self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)

                    self.logger.info('Evaluation! '+f'Current task: {curr_task}, task_epoch: {epoch}, task step: {self.task_step}, ' + \
                        f'global step: {self.global_step}, {val_task}: eval loss: {eval_loss}, eval metric: {eval_metric}')
                    if 'CLS' == TASK2INFO[val_task]['task_type']:
                        all_res_per_eval[val_task] = {'intent_acc': eval_metric['acc']}
                    else:
                        all_res_per_eval[val_task] = {'slot_f1': eval_metric['f1']}
                # print(all_res_per_eval,flush=True)
                all_tasks_res.append(all_res_per_eval)

                self.valid_writer.add_scalar(f'avg/loss', sum(avg_metric_loss)/len(avg_metric_loss), self.global_step)
                self.valid_writer.add_scalar(f'avg/metric', sum(avg_metric)/len(avg_metric), self.global_step)

            # * Save model checkpoint.
            if self.global_step % (5 * self.args.eval_steps) == 0:
                task_save_folder = os.path.join(self.save_folder, str(curr_task))
                os.makedirs(task_save_folder, exist_ok=True)
                task_model_save_path = os.path.join(task_save_folder, 'model_'+'{:03d}'.format(self.global_step) + '.pt')
                torch.save(VAE.state_dict(), task_model_save_path)
                self.logger.info("Saving model checkpoint of %s to %s"%(curr_task, task_model_save_path))

        self.logger.info("Training loop. The %dth epoch completed."%epoch)

        # * Generate reconstructed inputs during training.
        if epoch % 20 == 0 and epoch != 0:
            if self.args.generate_dur_train:
                # NOTE(review): `test_loader` is not defined in this scope — this
                # would raise NameError if generate_dur_train is enabled; confirm.
                generate(VAE, test_loader, self.save_folder, epoch)
                self.logger.info('Finished generating reconstructed utterences (inputs).')

    # * Store latent variables z of current task into the memory.
    if self.args.use_memory and self.memory:
        latent_info = {}
        curr_data = self.datasets[curr_task]['train']
        # Prior is conditioned on the task prompt wrapped in BOS/EOS.
        p_tokens = [curr_data.tokz.bos_token_id] + curr_data.pseudo_data_prompt_id + [curr_data.tokz.eos_token_id]
        p_tokens = torch.LongTensor(p_tokens).to(self.device)
        prior_out = VAE.encoder(input_ids=p_tokens)
        prior_emb, _ = VAE.avg_attn(prior_out[0])
        latent_mean, latent_logvar = VAE.prior_mean(prior_emb), VAE.prior_logvar(prior_emb)
        latent_info['prior'] = (latent_mean, latent_logvar)
        if self.args.save_z:
            prior_latent_z = VAE.reparameterize(latent_mean, latent_logvar)
            prior_latent_proj = self.args.alpha_z * VAE.latent_mlp(prior_latent_z)
            latent_info['prior_z'] = prior_latent_proj

        # * Random save some postriors.
        latent_info['posterior'] = []
        select_idx_list = random.sample(range(len(ITER)), self.args.memory_size)
        curr_train_dataset = self.datasets[curr_task]['train'].data
        if self.args.save_z: latent_info['posterior_z'] = []
        for idx, dt in enumerate(curr_train_dataset): # save memory_size * curr_batch_size samples' posterior into the memory.
            if idx in select_idx_list:
                px_tokens = torch.tensor(dt['posterior_id']).to(self.device)
                post_out = VAE.encoder(input_ids=px_tokens)
                post_emb, _ = VAE.avg_attn(post_out[0])
                post_mean, post_logvar = VAE.post_mean(post_emb), VAE.post_logvar(post_emb)
                latent_info['posterior'].append((post_mean, post_logvar))
                post_latent_z = VAE.reparameterize(post_mean, post_logvar)
                post_latent_proj = self.args.alpha_z * VAE.latent_mlp(post_latent_z)
                latent_info['posterior_z'].append(post_latent_proj)
        self.memory.push(curr_task, (p_tokens, latent_info))

    # * Model saving.
    task_final_save_path = os.path.join(self.save_folder, str(curr_task)+'_model'+'.pt')
    self.logger.info('Saving model checkpoint of task %s to %s'%(curr_task, task_final_save_path))
    model_save_path = os.path.join(self.save_folder, 'model_last.pt')
    torch.save(VAE.state_dict(), model_save_path)
    self.logger.info('Saving total model checkpoint to %s', model_save_path)
    self.logger.info('='*25+'Training complete!'+'='*25)

    if self.args.add_kd:
        return VAE, prev_VAE, all_tasks_res
    return VAE, all_tasks_res
|
|
|
|
def _eval(self, model, task, loss_fn):
    """Run a no-grad validation pass over `task`'s val split.

    Args:
        model: CVAE whose decoder is scored via compute_lm_loss.
        task: task name indexing self.data_loaders.
        loss_fn: per-token loss forwarded to compute_lm_loss.

    Returns:
        MetricsRecorder holding the per-batch 'lm_loss' values (all-reduced
        across ranks when running distributed).
    """
    lm_recorder = MetricsRecorder(self.device, 'lm_loss')

    with torch.no_grad():
        model.eval()
        for i, data in enumerate(self.data_loaders[task]['val']):
            # (removed unused local `bs = data['all_id'].shape[0]`)
            lm_total = get_model_input(data, input_type='lm', step=100, tokz=self.tokz)
            loss = compute_lm_loss(self.device, model, loss_fn, lm_total)
            lm_recorder.metric_update(i, {'lm_loss': loss})

    if self.rank != -1:
        # Aggregate metrics across distributed workers.
        lm_recorder.all_reduce()
    return lm_recorder
|
|
|
|
# * Calculate accuracy of the model
|
|
# * Calculate accuracy of the model
def _eval_acc(self, model, task, out_dir=None, all_prior_info=None):
    """Evaluate exact-match answer accuracy on `task`'s validation split.

    Args:
        model: CVAE used for constrained generation via get_answer.
        task: task name indexing self.datasets / self.data_loaders.
        out_dir: when set, per-example context/gold/pred JSON lines are written
            to `<out_dir>/<task>_step<global_step>.json`.
        all_prior_info: accepted for interface parity with _eval_f1 (unused here).

    Returns:
        {'acc': float} on the master rank (0 or -1), None on other ranks.
    """
    # +1 leaves room for the EOS token during generation.
    max_ans_len = max(self.datasets[task]['train'].max_ans_len, self.datasets[task]['val'].max_ans_len, self.datasets[task]['test'].max_ans_len) + 1
    if out_dir is not None:
        out_dir = os.path.join(out_dir, f'{task}_step{self.global_step}.json')
        out_dir_content = []

    with torch.no_grad():
        model.eval()
        pred_ans_all, gold_ans_all = [], []
        context_all = []
        all_pred_task_name = []

        for i, data in enumerate(self.data_loaders[task]['val']):
            # * Task incremental learning, know task id during testing.
            if not self.args.classIL:
                example = data['context_id'].to(self.device, non_blocking=True)
                example_lens = data['context_lens'].to(self.device, non_blocking=True)
            elif self.args.general_prompt:
                # BUG FIX: was `self.general_prompt`, which raises AttributeError —
                # the flag lives on self.args (see _eval_f1 for the same branch).
                example = data['general_context_id'].to(self.device, non_blocking=True)
                example_lens = data['general_context_lens'].to(self.device, non_blocking=True)

            if out_dir is not None:
                context_all.append(example)
            gold_ans = data['ans_id'].to(self.device, non_blocking=True)

            # Greedy decoding (sampling=False) of at most max_ans_len tokens.
            pred_ans = get_answer(self.tokz, model, example, example_lens, max_ans_len, sampling=False, args=self.args)

            pred_ans_all.append(pred_ans)
            gold_ans_all.append(gold_ans)

        if self.args.classIL:
            # NOTE(review): all_pred_task_name is never populated in this method,
            # so guard against division by zero instead of crashing.
            if all_pred_task_name:
                right_pred_task = [1 for pred_task in all_pred_task_name if pred_task==task]
                self.logger.info('Correct prediction of task name ratio is: '+str(len(right_pred_task)/len(all_pred_task_name)))

        # collect results from different nodes
        pred_ans_all = communicate_tensor(pred_ans_all, pad_token=self.tokz.eos_token_id).tolist()
        gold_ans_all = communicate_tensor(gold_ans_all,pad_token=self.tokz.eos_token_id).tolist()
        if out_dir is not None:
            context_all = communicate_tensor(context_all,pad_token=self.tokz.eos_token_id).tolist()

        if self.rank in [0, -1]:
            correct_count, total_count = 0, len(pred_ans_all)
            for i in range(total_count):
                if compare_tokens(pred_ans_all[i], gold_ans_all[i], self.tokz.eos_token_id):
                    correct_count += 1
                if out_dir is not None:
                    res = {}
                    res['context'] = self.tokz.decode(strip_list(context_all[i], self.tokz.eos_token_id))
                    res['ans_gold'] = self.tokz.decode(cut_eos(gold_ans_all[i], self.tokz.eos_token_id))
                    res['ans_pred'] = self.tokz.decode(cut_eos(pred_ans_all[i], self.tokz.eos_token_id))
                    out_dir_content.append(res)

            if out_dir is not None:
                with open(out_dir, 'w', encoding='utf-8') as outfile:
                    for res in out_dir_content:
                        print(json.dumps(res), file=outfile)
            model.train()
            return {'acc': correct_count / total_count}
        else:
            model.train()
            return None
|
|
|
|
# * Calculate f1 of the model
|
|
def _eval_f1(self, model, task, out_dir=None, all_prior_info=None):
    """Evaluate slot-tagging F1 of `model` on the validation split of `task`.

    Answers are generated autoregressively with `get_answer`, gathered
    across distributed workers with `communicate_tensor`, and scored on
    rank 0/-1. When `out_dir` is given, per-example records are dumped to
    `<out_dir>/<task>_step<global_step>.json`, one JSON object per line.

    Returns {'f1': score} on rank 0/-1, None on other ranks. Puts the
    model back into train mode before returning.
    """
    # +1 leaves room for the EOS token during generation.
    max_ans_len = max(self.datasets[task]['train'].max_ans_len,
                      self.datasets[task]['val'].max_ans_len,
                      self.datasets[task]['test'].max_ans_len) + 1

    if out_dir is not None:
        out_dir = os.path.join(out_dir, f'{task}_step{self.global_step}.json')
        out_dir_content = []

    with torch.no_grad():
        model.eval()
        pred_ans_all, gold_ans_all = [], []
        context_all = []

        for i, data in enumerate(self.data_loaders[task]['val']):
            # * Task incremental learning: task id is known during testing.
            if not self.args.classIL:
                example = data['context_id'].to(self.device, non_blocking=True)
                example_lens = data['context_lens'].to(self.device, non_blocking=True)
            elif self.args.general_prompt:
                example = data['general_context_id'].to(self.device, non_blocking=True)
                example_lens = data['general_context_lens'].to(self.device, non_blocking=True)
            # NOTE(review): classIL without general_prompt leaves `example`
            # unbound — presumably that combination never reaches this
            # method; confirm against the callers.

            if out_dir is not None:
                context_all.append(example)
            gold_ans = data['ans_id'].to(self.device, non_blocking=True)
            pred_ans = get_answer(self.tokz, model, example, example_lens, max_ans_len,
                                  sampling=False, args=self.args)

            pred_ans_all.append(pred_ans)
            gold_ans_all.append(gold_ans)

        # Collect results from different nodes.
        pred_ans_all = communicate_tensor(pred_ans_all, pad_token=self.tokz.eos_token_id).tolist()
        gold_ans_all = communicate_tensor(gold_ans_all, pad_token=self.tokz.eos_token_id).tolist()
        if out_dir is not None:
            context_all = communicate_tensor(context_all, pad_token=self.tokz.eos_token_id).tolist()

        if self.rank in [0, -1]:
            # Decode "name: value; name: value" answers into slot lists,
            # keeping only well-formed "name: value" fragments.
            pred_slots = [[i.strip() for i in self.tokz.decode(cut_eos(l, self.tokz.eos_token_id)).split(';')] for l in pred_ans_all]
            gold_slots = [[i.strip() for i in self.tokz.decode(cut_eos(l, self.tokz.eos_token_id)).split(';')] for l in gold_ans_all]
            pred_slots = [[j for j in i if ':' in j] for i in pred_slots]
            gold_slots = [[j for j in i if ':' in j] for i in gold_slots]

            # BUGFIX: build dump records only when an output file was
            # requested. `out_dir_content` is defined and `context_all`
            # gathered only when out_dir is not None; the previous
            # unguarded loop raised NameError (and indexed raw tensors)
            # whenever out_dir was None on rank 0/-1.
            if out_dir is not None:
                for i in range(len(pred_ans_all)):
                    res = {}
                    res['context'] = self.tokz.decode(strip_list(context_all[i], self.tokz.eos_token_id))
                    res['ans_gold'] = self.tokz.decode(cut_eos(gold_ans_all[i], self.tokz.eos_token_id))
                    res['ans_pred'] = self.tokz.decode(cut_eos(pred_ans_all[i], self.tokz.eos_token_id))
                    out_dir_content.append(res)

                with open(out_dir, 'w', encoding='utf-8') as outfile:
                    for res in out_dir_content:
                        print(json.dumps(res), file=outfile)
            model.train()
            return {'f1': slot_f1_score(pred_slots, gold_slots)}
        else:
            model.train()
            return None
|
|
|
|
def _slot_f1_score(pred_slots, true_slots):
|
|
slot_types = set([slot.split(":")[0] for row in true_slots for slot in row])
|
|
slot_type_f1_scores = []
|
|
|
|
for slot_type in slot_types:
|
|
predictions_for_slot = [[p for p in prediction if slot_type in p] for prediction in pred_slots]
|
|
labels_for_slot = [[l for l in label if slot_type in l] for label in true_slots]
|
|
|
|
proposal_made = [len(p) > 0 for p in predictions_for_slot]
|
|
has_label = [len(l) > 0 for l in labels_for_slot]
|
|
prediction_correct = [prediction == label for prediction, label in zip(predictions_for_slot, labels_for_slot)]
|
|
true_positives = sum([
|
|
int(proposed and correct)
|
|
for proposed, correct in zip(proposal_made, prediction_correct)])
|
|
|
|
num_predicted = sum([int(proposed) for proposed in proposal_made])
|
|
num_to_recall = sum([int(hl) for hl in has_label])
|
|
|
|
precision = true_positives / (1e-5 + num_predicted)
|
|
recall = true_positives / (1e-5 + num_to_recall)
|
|
|
|
f1_score = 2 * precision * recall / (1e-5 + precision + recall)
|
|
slot_type_f1_scores.append(f1_score)
|
|
|
|
return np.mean(slot_type_f1_scores)
|
|
|
|
# * Train across tasks through time.
|
|
def train(self, tasks):
    """Sequentially train one model over every task in `tasks` (continual learning).

    Each task is trained via `self._train`, which receives the model so
    far plus the list of previously seen tasks and returns the updated
    model and the accumulated per-task metrics. With `--add_kd`, a copy
    of the model (`prev_model`) is kept as the distillation teacher and
    refreshed after every task. Metrics are written to
    `<save_folder>/metrics.json` (one JSON object per line) and, when a
    replay memory is used, the memory is saved to
    `<memory_path>/memory.pt`.
    """
    model = None
    prev_model = None
    self.logger.info('Total tasks number is '+str(len(tasks)))

    all_tasks_res = []
    res_file = os.path.join(self.save_folder, 'metrics.json')

    for i in range(len(tasks)):
        if self.args.add_kd:
            # Knowledge distillation: _train also returns the teacher model.
            model, prev_model, all_tasks_res = self._train(model, prev_model=prev_model, curr_task=tasks[i], prev_tasks=tasks[:i], all_tasks_res=all_tasks_res)
        else:
            model, all_tasks_res = self._train(model, curr_task=tasks[i], prev_tasks=tasks[:i], all_tasks_res=all_tasks_res)

        # * Update the previous model with the current model learned from one task.
        if self.args.add_kd:
            prev_model.load_state_dict(model.state_dict())

        self.logger.info('We have trained %d tasks'%(i+1))
        if self.args.use_memory:
            self.logger.info('Memory has saved information of %d tasks'%(len(self.memory.memory.keys()))+'-'*10)

    # * Evaluation metrics saving.
    # NOTE(review): the original file's indentation was not recoverable
    # here — these two saves may have been inside the task loop (per-task
    # checkpointing); final file/memory contents are identical either way.
    with open(res_file,'w', encoding='utf-8') as f:
        for e in all_tasks_res:
            print(json.dumps(e,ensure_ascii=False), file=f)

    # * Save memory if use.
    if self.args.use_memory:
        save_memory_path = os.path.join(self.args.memory_path, 'memory.pt')
        torch.save(self.memory, save_memory_path)
|