add: ssll
parent 1034f23eec
commit fa8479e963
@ -0,0 +1,40 @@
# SSLL

## Introduction

-----

SSLL (Semi-Supervised Lifelong Language Learning) is built by the Conversational AI Team, Alibaba DAMO Academy.

The corresponding paper has been published in Findings of EMNLP 2022: "***Semi-Supervised Lifelong Language Learning***".

## SSLL Implementation

-----

### Requirements

```
pip install -r requirements.txt
```

### Dataset

The datasets used in the experiments follow [LAMOL](https://arxiv.org/abs/1909.03329).

### Model Training and Evaluation

```
sh scripts/lltrain.sh
```

The files required for training are under the folder `unifymodel`.

## Citation

----

If you use our code or find SSLL useful for your work, please cite our paper as:\
@inproceedings{zhao2022semi,\
title={Semi-Supervised Lifelong Language Learning},\
author={Zhao, Yingxiu and Zheng, Yinhe and Yu, Bowen and Tian, Zhiliang and Lee, Dongkyu and Sun, Jian and Li, Yongbin and Zhang, Nevin L.},\
booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing: Findings},\
year={2022}\
}

LAMOL citation:\
@inproceedings{sun2019lamol,\
title={LAMOL: LAnguage MOdeling for Lifelong Language Learning},\
author={Sun, Fan-Keng and Ho, Cheng-Hao and Lee, Hung-Yi},\
booktitle={International Conference on Learning Representations},\
year={2020}\
}
|
@ -0,0 +1,60 @@
|
|||
from transformers import MarianMTModel, MarianTokenizer
|
||||
import torch
|
||||
|
||||
# target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
# tgt_tokz_path='out_BT/tgt_tokz'
|
||||
# tgt_model_path='out_BT/tgt_model'
|
||||
# target_tokenizer = MarianTokenizer.from_pretrained(tgt_tokz_path)
|
||||
# target_model = MarianMTModel.from_pretrained(tgt_model_path)
|
||||
|
||||
# # en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
|
||||
# en_tokz_path='out_BT/en_tokz'
|
||||
# en_model_path='out_BT/en_model'
|
||||
# en_tokenizer = MarianTokenizer.from_pretrained(en_tokz_path)
|
||||
# en_model = MarianMTModel.from_pretrained(en_model_path)
|
||||
|
||||
# target_tokenizer.save_pretrained('out_BT/tgt_tokz.pt')
|
||||
# target_model.save_pretrained('out_BT/tgt_model.pt')
|
||||
# en_tokenizer.save_pretrained('out_BT/en_tokz.pt')
|
||||
# en_model.save_pretrained('out_BT/en_model.pt')
|
||||
|
||||
# target_model, en_model = target_model.to('cuda'), en_model.to('cuda')
|
||||
|
||||
def translate(texts, model, tokenizer, language="fr"):
|
||||
# Prepare the text data into appropriate format for the model
|
||||
def template(
|
||||
text): return f"{text}" if language == "en" else f">>{language}<< {text}"
|
||||
src_texts = [template(text) for text in texts]
|
||||
|
||||
# Tokenize the texts
|
||||
encoded = tokenizer.prepare_seq2seq_batch(src_texts)
|
||||
# input_ids = tokenizer.encode(src_texts)
|
||||
for k, v in encoded.items():
|
||||
encoded[k] = torch.tensor(v)
|
||||
encoded = encoded.to('cuda')
|
||||
|
||||
# Generate translation using model
|
||||
translated = model.generate(**encoded)
|
||||
|
||||
# Convert the generated tokens indices back into text
|
||||
translated_texts = tokenizer.batch_decode(
|
||||
translated, skip_special_tokens=True)
|
||||
|
||||
return translated_texts
|
||||
|
||||
|
||||
def back_translate(texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr"):
|
||||
# Translate from source to target language
|
||||
fr_texts = translate(texts, target_model, target_tokenizer,
|
||||
language=target_lang)
|
||||
|
||||
# Translate from target language back to source language
|
||||
back_translated_texts = translate(fr_texts, en_model, en_tokenizer,
|
||||
language=source_lang)
|
||||
|
||||
return back_translated_texts
|
||||
|
||||
# en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
|
||||
# # en_texts = 'I love you.'
|
||||
# aug_texts = back_translate(en_texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr")
|
||||
# print(aug_texts)
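
# A minimal usage sketch (not part of the original pipeline). It assumes the two
# public MarianMT checkpoints named in the comments above can be downloaded from
# the Hugging Face hub, that a CUDA device is available (translate() moves the
# batch to 'cuda'), and that the installed transformers version still provides
# tokenizer.prepare_seq2seq_batch, which translate() relies on.
if __name__ == '__main__':
    target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'  # en -> Romance languages
    en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'      # Romance languages -> en

    target_tokenizer = MarianTokenizer.from_pretrained(target_model_name)
    target_model = MarianMTModel.from_pretrained(target_model_name).to('cuda')
    en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
    en_model = MarianMTModel.from_pretrained(en_model_name).to('cuda')

    en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
    aug_texts = back_translate(en_texts, target_model, target_tokenizer,
                               en_model, en_tokenizer,
                               source_lang='en', target_lang='fr')
    print(aug_texts)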
|
|
@ -0,0 +1,348 @@
|
|||
import torch
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
from settings import parse_args
|
||||
from eda import *
|
||||
# from train import special_tokens, special_token_ids
|
||||
|
||||
args = parse_args()
|
||||
class LBIDDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
# self.CLS_PROMPT = CLSPrompt()
|
||||
# self.SlotTagging_PROMPT = SlotTaggingPrompt()
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = special_token_ids
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def get_answer(self, intent): # get answer text from intent
|
||||
# replace - and _ to make the text seem more natural
|
||||
ans = intent.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['userInput']['text']
|
||||
ans = self.get_answer(example['intent'])
|
||||
num_aug = 0 # Number of augmented sentences per sample.
|
||||
if not args.meantc or num_aug == 0:
|
||||
# print('Run this.', flush=True)
|
||||
return [(text, ans)]
|
||||
else:
|
||||
aug_text_list = eda(text, num_aug=num_aug)
|
||||
res = [(aug_text, ans) for aug_text in aug_text_list]
|
||||
res.append((text, ans))
|
||||
return res
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
# print(d,flush=True)
|
||||
input_text, ans_text = d
|
||||
input_id = self.tokz.encode(input_text)
|
||||
ans_id = self.tokz.encode(ans_text)
|
||||
self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
|
||||
return {
|
||||
'all_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']] + ans_id + [self.tokz.eos_token_id],
|
||||
'input_id': [self.tokz.bos_token_id] + input_id,
|
||||
'context_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
|
||||
'ans_id': ans_id + [self.tokz.eos_token_id]}
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
# print('data',data[:10],flush=True)
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
|
||||
class ULBIDDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = special_token_ids
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
data = [self.parse_example(i) for i in data] # [utter, label]
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def get_answer(self, intent): # get answer text from intent
|
||||
# replace - and _ to make the text seem more natural
|
||||
ans = intent.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['userInput']['text']
|
||||
ans = self.get_answer(example['intent'])
|
||||
return text, ans
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
input_text, ans_text = d
|
||||
input_id = self.tokz.encode(input_text)
|
||||
ans_id = self.tokz.encode(ans_text)
|
||||
self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
|
||||
return {
|
||||
'all_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']] + ans_id + [self.tokz.eos_token_id],
|
||||
'input_id': [self.tokz.bos_token_id] + input_id,
|
||||
# 'input_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
|
||||
'context_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
|
||||
'ans_id': ans_id + [self.tokz.eos_token_id]}
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
|
||||
class PinnedBatch:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, k):
|
||||
return self.data[k]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def pin_memory(self):
|
||||
for k in self.data.keys():
|
||||
self.data[k] = self.data[k].pin_memory()
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return self.data.__repr__()
|
||||
|
||||
def __str__(self):
|
||||
return self.data.__str__()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
|
||||
def pad_seq(seq, pad, max_len, pad_left=False):
|
||||
if pad_left:
|
||||
return [pad] * (max_len - len(seq)) + seq
|
||||
else:
|
||||
return seq + [pad] * (max_len - len(seq))
|
||||
|
||||
|
||||
class PadBatchSeq:
|
||||
def __init__(self, pad_id=0):
|
||||
self.pad_id = pad_id
|
||||
|
||||
def __call__(self, batch):
|
||||
input_id = [i['input_id'] for i in batch]
|
||||
context_id = [i['context_id'] for i in batch]
|
||||
all_id = [i['all_id'] for i in batch]
|
||||
ans_id = [i['ans_id'] for i in batch]
|
||||
input_lens = [len(i) for i in input_id]
|
||||
context_lens = [len(i) for i in context_id]
|
||||
all_lens = [len(i) for i in all_id]
|
||||
ans_lens = [len(i) for i in ans_id]
|
||||
|
||||
input_mask = torch.ByteTensor([[1] * input_lens[i] + [0] * (max(input_lens)-input_lens[i]) for i in range(len(input_id))])
|
||||
all_mask = torch.ByteTensor([[1] * all_lens[i] + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))])
|
||||
all_label_mask = torch.ByteTensor([[0] * (context_lens[i]) + [1] * (all_lens[i] - context_lens[i]) + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))]) # Only calculate losses on answer tokens.
|
||||
|
||||
res = {}
|
||||
res["input_id"] = torch.tensor([pad_seq(i, self.pad_id, max(input_lens)) for i in input_id], dtype=torch.long)
|
||||
res["context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(context_lens), pad_left=True) for i in context_id], dtype=torch.long)
|
||||
res["all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(all_lens)) for i in all_id], dtype=torch.long)
|
||||
res["ans_id"] = torch.tensor([pad_seq(i, self.pad_id, max(ans_lens)) for i in ans_id], dtype=torch.long)
|
||||
|
||||
res["input_lens"] = torch.tensor(input_lens, dtype=torch.long)
|
||||
res["context_lens"] = torch.tensor(context_lens, dtype=torch.long)
|
||||
res["all_lens"] = torch.tensor(all_lens, dtype=torch.long)
|
||||
res["ans_lens"] = torch.tensor(ans_lens, dtype=torch.long)
|
||||
|
||||
res['input_mask'] = input_mask
|
||||
res['all_mask'] = all_mask
|
||||
res['all_label_mask'] = all_label_mask
|
||||
|
||||
return PinnedBatch(res)
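
# Usage sketch (an assumption, not taken from the training script): PadBatchSeq is a
# collate_fn, so a labeled split would typically be wrapped like this, where `tokz`
# is the tokenizer used above (it must expose bos/eos token ids and contain the
# 'ANS:' token so convert_tokens_to_ids('ANS:') is meaningful) and the data path is
# illustrative.
#
#   special_token_ids = {'ans_token': tokz.convert_tokens_to_ids('ANS:')}
#   train_set = LBIDDataset('banking', tokz, 'data/banking/label_train.json',
#                           special_token_ids=special_token_ids)
#   train_loader = torch.utils.data.DataLoader(
#       train_set, batch_size=16, shuffle=True, pin_memory=True,
#       collate_fn=PadBatchSeq(pad_id=tokz.pad_token_id))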
|
||||
|
||||
|
||||
# * Information for different tasks.
|
||||
if args.data_type == 'intent':
|
||||
TASK2INFO = {
|
||||
"banking": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "banking",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"clinc": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "clinc",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"hwu": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "hwu",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"tod": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "FB_TOD_SF",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split1": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split1",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split2": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split2",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split3": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split3",
|
||||
"task_type": "CLS",
|
||||
}
|
||||
}
|
||||
else:
|
||||
TASK2INFO = {
|
||||
"dstc8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "dstc8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"restaurant8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "restaurant8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_eng": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_eng",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_trivia": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_trivia",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_restaurant": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_restaurant",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_unlabel_data(path, task):
|
||||
info = TASK2INFO[task]
|
||||
data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
|
||||
return data_path
|
||||
|
||||
def get_unlabel_dict(path, task, tokz, args):
|
||||
info = TASK2INFO[task]
|
||||
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
if args.data_type == 'intent':
|
||||
return LBIDDataset(task, tokz, path, args.ctx_max_len, special_token_ids=special_token_ids)
|
||||
else:
|
||||
return None
|
||||
|
||||
class MixedDataset(LBIDDataset):
|
||||
def __init__(self, unlabel_data, label_data, tokz, ctx_max_len=100):
|
||||
self.ctx_max_len = ctx_max_len
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
|
||||
self.data = []
|
||||
self.data += unlabel_data
|
||||
self.data += label_data
|
||||
|
||||
class AugDataset(LBIDDataset):
|
||||
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = special_token_ids
|
||||
with open(label_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
with open(unlabel_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data += [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['userInput']['text']
|
||||
aug_text_list = eda(text, num_aug=args.num_aug)
|
||||
ans = self.get_answer(example['intent'])
|
||||
res = [(aug_text, ans) for aug_text in aug_text_list]
|
||||
res.append((text, ans))
|
||||
return res
|
||||
|
||||
def get_datasets(path, tasks, tokz, ctx_max_len=100, special_token_ids=None):
|
||||
res = {}
|
||||
|
||||
for task in tasks:
|
||||
res[task] = {}
|
||||
info = TASK2INFO[task]
|
||||
if args.data_type == "intent":
|
||||
res[task]['label_train'] = LBIDDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
res[task]['unlabel_train'] = ULBIDDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], 'unlabel_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
if args.meantc:
|
||||
filename = 'unlabel_' + str(args.num_unlabel) + '_train.json'
|
||||
res[task]['aug_train'] = AugDataset(task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'),
|
||||
os.path.join(path, info['dataset_folder'], filename), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'valid.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
return res
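
# Sketch of how get_datasets is wired up (hypothetical values; the real ones come
# from settings.parse_args and the training scripts under `unifymodel`):
#
#   special_token_ids = {'ans_token': tokz.convert_tokens_to_ids('ANS:')}
#   datasets = get_datasets('./data', args.tasks, tokz,
#                           ctx_max_len=args.ctx_max_len,
#                           special_token_ids=special_token_ids)
#   banking_label_train = datasets['banking']['label_train']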
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
(wikisql sst woz.en squad srl)
|
||||
(srl squad wikisql woz.en sst)
|
||||
(wikisql woz.en srl sst squad)
|
||||
(srl sst squad woz.en wikisql)
|
||||
(woz.en sst squad wikisql srl)
|
|
@ -0,0 +1,42 @@
|
|||
import json, os, argparse
|
||||
from settings import parse_args
|
||||
import random
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
|
||||
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
|
||||
parser.add_argument('--num_label',type=int,help='Number of labeled data for each training dataset.')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_outdir = args.data_outdir
|
||||
train_dir = args.train_dir
|
||||
K = args.num_label
|
||||
|
||||
train_file = os.path.join(train_dir,'ripe_data','train.json')
|
||||
train_out_label = os.path.join(data_outdir,'label_train.json')
|
||||
train_out_unlabel = os.path.join(data_outdir,'unlabel_train.json')
|
||||
|
||||
with open(train_file,'r', encoding='utf-8') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
|
||||
label_idx_list = random.sample(range(len(data)), K)
|
||||
label_data = [data[i] for i in range(len(data)) if i in label_idx_list]
|
||||
print(K,len(label_data))
|
||||
unlabel_data = [data[i] for i in range(len(data)) if i not in label_idx_list]
|
||||
|
||||
with open(train_out_label, 'w', encoding='utf-8') as f:
|
||||
for i in label_data:
|
||||
print(json.dumps(i, ensure_ascii=False), file=f)
|
||||
|
||||
with open(train_out_unlabel, 'w', encoding='utf-8') as f:
|
||||
for i in unlabel_data:
|
||||
print(json.dumps(i, ensure_ascii=False), file=f)
|
||||
|
||||
str_list = ['100','500','2000']
|
||||
for i in str_list:
|
||||
random.shuffle(unlabel_data)
|
||||
with open(os.path.join(data_outdir, 'unlabel_'+i+'_train.json'), 'w', encoding='utf-8') as f:
|
||||
for j in unlabel_data[:int(i)]:
|
||||
print(json.dumps(j, ensure_ascii=False), file=f)
|
||||
|
||||
print('Finish dividing',args.train_dir)
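
# Example invocation (a sketch; the script name and paths are illustrative):
#   python divide_label_unlabel.py --train_dir data/banking --data_outdir data/banking --num_label 100
# It reads <train_dir>/ripe_data/train.json and writes label_train.json,
# unlabel_train.json and unlabel_{100,500,2000}_train.json under --data_outdir.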
|
|
@ -0,0 +1,246 @@
|
|||
# Easy data augmentation techniques for text classification
|
||||
# Jason Wei and Kai Zou
|
||||
import nltk
|
||||
# nltk.download('wordnet', '/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
|
||||
# nltk.download('omw-1.4')
|
||||
# from nltk_data.corpora import wordnet
|
||||
nltk.data.path.append('/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
from nltk.corpus import wordnet  # needed by get_synonyms() below
|
||||
import re
|
||||
import random
|
||||
from random import shuffle
|
||||
random.seed(1)
|
||||
|
||||
#stop words list
|
||||
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
|
||||
'ours', 'ourselves', 'you', 'your', 'yours',
|
||||
'yourself', 'yourselves', 'he', 'him', 'his',
|
||||
'himself', 'she', 'her', 'hers', 'herself',
|
||||
'it', 'its', 'itself', 'they', 'them', 'their',
|
||||
'theirs', 'themselves', 'what', 'which', 'who',
|
||||
'whom', 'this', 'that', 'these', 'those', 'am',
|
||||
'is', 'are', 'was', 'were', 'be', 'been', 'being',
|
||||
'have', 'has', 'had', 'having', 'do', 'does', 'did',
|
||||
'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
|
||||
'because', 'as', 'until', 'while', 'of', 'at',
|
||||
'by', 'for', 'with', 'about', 'against', 'between',
|
||||
'into', 'through', 'during', 'before', 'after',
|
||||
'above', 'below', 'to', 'from', 'up', 'down', 'in',
|
||||
'out', 'on', 'off', 'over', 'under', 'again',
|
||||
'further', 'then', 'once', 'here', 'there', 'when',
|
||||
'where', 'why', 'how', 'all', 'any', 'both', 'each',
|
||||
'few', 'more', 'most', 'other', 'some', 'such', 'no',
|
||||
'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
|
||||
'very', 's', 't', 'can', 'will', 'just', 'don',
|
||||
'should', 'now', '']
|
||||
|
||||
#cleaning up text
|
||||
|
||||
def get_only_chars(line):
|
||||
clean_line = ""
|
||||
|
||||
line = line.replace("’", "")
|
||||
line = line.replace("'", "")
|
||||
line = line.replace("-", " ") # replace hyphens with spaces
|
||||
line = line.replace("\t", " ")
|
||||
line = line.replace("\n", " ")
|
||||
line = line.lower()
|
||||
|
||||
for char in line:
|
||||
if char in 'qwertyuiopasdfghjklzxcvbnm ':
|
||||
clean_line += char
|
||||
else:
|
||||
clean_line += ' '
|
||||
|
||||
clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces
|
||||
if clean_line and clean_line[0] == ' ':  # guard against empty strings
|
||||
clean_line = clean_line[1:]
|
||||
return clean_line
|
||||
|
||||
########################################################################
|
||||
# *Synonym replacement
|
||||
# Replace n words in the sentence with synonyms from wordnet
|
||||
########################################################################
|
||||
|
||||
|
||||
#for the first time you use wordnet
|
||||
# import nltk
|
||||
# nltk.download('wordnet')
|
||||
|
||||
def synonym_replacement(words, n):
|
||||
new_words = words.copy()
|
||||
random_word_list = list(
|
||||
set([word for word in words if word not in stop_words]))
|
||||
random.shuffle(random_word_list)
|
||||
num_replaced = 0
|
||||
for random_word in random_word_list:
|
||||
synonyms = get_synonyms(random_word)
|
||||
if len(synonyms) >= 1:
|
||||
synonym = random.choice(list(synonyms))
|
||||
new_words = [synonym if word == random_word else word for word in new_words]
|
||||
#print("replaced", random_word, "with", synonym)
|
||||
num_replaced += 1
|
||||
if num_replaced >= n: # only replace up to n words
|
||||
break
|
||||
|
||||
#this is stupid but we need it, trust me
|
||||
sentence = ' '.join(new_words)
|
||||
new_words = sentence.split(' ')
|
||||
|
||||
return new_words
|
||||
|
||||
|
||||
def get_synonyms(word):
|
||||
synonyms = set()
|
||||
for syn in wordnet.synsets(word):
|
||||
for l in syn.lemmas():
|
||||
synonym = l.name().replace("_", " ").replace("-", " ").lower()
|
||||
synonym = "".join(
|
||||
[char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
|
||||
synonyms.add(synonym)
|
||||
if word in synonyms:
|
||||
synonyms.remove(word)
|
||||
return list(synonyms)
|
||||
|
||||
########################################################################
|
||||
# *Random deletion
|
||||
# Randomly delete words from the sentence with probability p
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_deletion(words, p):
|
||||
|
||||
#obviously, if there's only one word, don't delete it
|
||||
if len(words) == 1:
|
||||
return words
|
||||
|
||||
#randomly delete words with probability p
|
||||
new_words = []
|
||||
for word in words:
|
||||
r = random.uniform(0, 1)
|
||||
if r > p:
|
||||
new_words.append(word)
|
||||
|
||||
#if you end up deleting all words, just return a random word
|
||||
if len(new_words) == 0 and len(words)>1:
|
||||
rand_int = random.randint(0, len(words)-1)
|
||||
return [words[rand_int]]
|
||||
|
||||
return new_words
|
||||
|
||||
########################################################################
|
||||
# *Random swap
|
||||
# Randomly swap two words in the sentence n times
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_swap(words, n):
|
||||
new_words = words.copy()
|
||||
for _ in range(n):
|
||||
if len(new_words)>1:
|
||||
new_words = swap_word(new_words)
|
||||
else:
|
||||
new_words = words.copy()
|
||||
return new_words
|
||||
|
||||
|
||||
def swap_word(new_words):
|
||||
random_idx_1 = random.randint(0, len(new_words)-1)
|
||||
random_idx_2 = random_idx_1
|
||||
counter = 0
|
||||
while random_idx_2 == random_idx_1:
|
||||
random_idx_2 = random.randint(0, len(new_words)-1)
|
||||
counter += 1
|
||||
if counter > 3:
|
||||
return new_words
|
||||
new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
|
||||
return new_words
|
||||
|
||||
########################################################################
|
||||
# *Random insertion
|
||||
# Randomly insert n words into the sentence
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_insertion(words, n):
|
||||
new_words = words.copy()
|
||||
for _ in range(n):
|
||||
add_word(new_words)
|
||||
return new_words
|
||||
|
||||
|
||||
def add_word(new_words):
|
||||
synonyms = []
|
||||
counter = 0
|
||||
while len(synonyms) < 1:
|
||||
if len(new_words)>1:
|
||||
random_word = new_words[random.randint(0, len(new_words)-1)]
|
||||
synonyms = get_synonyms(random_word)
|
||||
counter += 1
|
||||
if counter >= 10:
|
||||
return
|
||||
random_synonym = synonyms[0]
|
||||
random_idx = random.randint(0, len(new_words)-1)
|
||||
new_words.insert(random_idx, random_synonym)
|
||||
|
||||
########################################################################
|
||||
# main data augmentation function
|
||||
########################################################################
|
||||
|
||||
|
||||
def eda(sentence, alpha_sr=0.0, alpha_ri=0.0, alpha_rs=0.1, p_rd=0.05, num_aug=5):
|
||||
|
||||
sentence = get_only_chars(sentence)
|
||||
words = sentence.split(' ')
|
||||
words = [word for word in words if word!='']
|
||||
num_words = len(words)
|
||||
|
||||
augmented_sentences = []
|
||||
num_new_per_technique = int(num_aug/4)+1
|
||||
|
||||
#sr
|
||||
if (alpha_sr > 0):
|
||||
n_sr = max(1, int(alpha_sr*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = synonym_replacement(words, n_sr)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#ri
|
||||
if (alpha_ri > 0):
|
||||
n_ri = max(1, int(alpha_ri*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_insertion(words, n_ri)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#rs
|
||||
if (alpha_rs > 0):
|
||||
n_rs = max(1, int(alpha_rs*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_swap(words, n_rs)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#rd
|
||||
if (p_rd > 0):
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_deletion(words, p_rd)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
augmented_sentences = [get_only_chars(sentence)
|
||||
for sentence in augmented_sentences]
|
||||
shuffle(augmented_sentences)
|
||||
#trim so that we have the desired number of augmented sentences
|
||||
if num_aug >= 1:
|
||||
augmented_sentences = augmented_sentences[:num_aug]
|
||||
else:
|
||||
keep_prob = num_aug / len(augmented_sentences)
|
||||
augmented_sentences = [
|
||||
s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
|
||||
#append the original sentence
|
||||
augmented_sentences.append(sentence)
|
||||
shuffle(augmented_sentences)
|
||||
return augmented_sentences
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
s = 'I like the cat with white and black fur.'
|
||||
print(eda(s))
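
# Note: with the defaults above, eda() returns num_aug + 1 sentences: num_aug
# augmented variants (drawn from whichever of synonym replacement, random
# insertion, random swap and random deletion have alpha/p > 0) plus the cleaned
# original sentence, in shuffled order.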
|
|
@ -0,0 +1,94 @@
|
|||
import torch
|
||||
import csv
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from fp16 import FP16_Module
|
||||
import GPUtil
|
||||
from collections import OrderedDict
|
||||
from settings import args, MODEL_CLASS, TOKENIZER, SPECIAL_TOKEN_IDS, init_logging
|
||||
from settings import MEMORY_FACTOR, LEN_FACTOR, TASK_DICT, MODEL_CONFIG, DATA_ATTRS, SPECIAL_TOKENS, CONFIG_CLASS, CONFIG_NAME
|
||||
from utils import QADataset, top_k_top_p_filtering, create_dataloader, logits_to_tokens, get_model_dir
|
||||
from utils import sample_sequence, remove_id, get_gen_token, lll_unbound_setting
|
||||
from metrics import *
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_one_to_one(curr_task, task_eval, model, score_dict,last=True, eval_data=None, out_dir=None):
|
||||
logger.info('Start to evaluate %s'%task_eval)
|
||||
max_ans_len = eval_data.max_ans_len + 1
|
||||
# print('Maximum answer length',max_ans_len,flush=True)
|
||||
|
||||
if out_dir is not None:
|
||||
if not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir)
|
||||
out_dir = os.path.join(
|
||||
out_dir, f'{task_eval}.json')
|
||||
out_dir_content = []
|
||||
|
||||
with torch.no_grad():
|
||||
model.set_active_adapters(task_eval)
|
||||
# print('model of task: {}'.format(task),list(model.parameters()), flush=True)
|
||||
model.eval()
|
||||
pred_ans_all, gold_ans_all = [], []
|
||||
context_all = []
|
||||
all_pred_task_name = []
|
||||
|
||||
for i, data in enumerate(eval_data):
|
||||
example = data['context_id'].cuda()
|
||||
example_mask = data['context_mask'].cuda()
|
||||
if out_dir is not None:
|
||||
context_all.append(example)
|
||||
gold_ans = data['ans_id'].cuda()
|
||||
|
||||
# Print model inputs to check.
|
||||
# if i==0:
|
||||
# print('context:',self.tokz.decode(example[0]),flush=True)
|
||||
# print('ans:',self.tokz.decode(gold_ans[0]),flush=True)
|
||||
|
||||
pred_ans, _ = get_answer(TOKENIZER, model, example, example_mask, args.max_ans_len, args=args, is_eval=True)  # not a method: use the module-level TOKENIZER and args imported from settings rather than self.*
|
||||
# print('context shape',context.shape,flush=True)
|
||||
# pred_ans = pred_ans[:,context.shape[1]:]
|
||||
pred_ans = pred_ans[:,1:] # Remove the first <pad> token.
|
||||
# print('pred',pred_ans,flush=True)
|
||||
# print('gold',gold_ans,flush=True) # output tensors
|
||||
|
||||
pred_ans_all.append(pred_ans)
|
||||
gold_ans_all.append(gold_ans)
|
||||
|
||||
|
||||
# * Compute score.
|
||||
# print('task_eval',task_eval,flush=True)
|
||||
get_test_score(task_eval, qa_results, score_dict)
|
||||
|
||||
return model, score_dict
|
||||
|
||||
def get_test_score(task_eval,qa_results,score_dict):
|
||||
# score = ours_compute_metrics(task_eval, qa_results)
|
||||
# score = compute_metrics(task_eval, qa_results)
|
||||
score = compute_metrics(
|
||||
qa_results,
|
||||
bleu='iwslt.en.de' in task_eval or 'multinli.in.out' in task_eval,
|
||||
dialogue='woz.en' in task_eval,
|
||||
rouge='cnn_dailymail' in task_eval,
|
||||
logical_form='wikisql' in task_eval,
|
||||
corpus_f1='zre' in task_eval
|
||||
)
|
||||
score_dict[task_eval] = score
|
||||
|
||||
def training_test_one_to_many(model,curr_task,step,last=True, dataloaders=None, args=None): # task_load like 'hwu' 'banking'
|
||||
model.eval()
|
||||
score_dicts = []
|
||||
logger.info("task: {}, step: {}".format(curr_task, step))
|
||||
score_dict = {k:None for k in args.tasks}
|
||||
with torch.no_grad():
|
||||
for task_eval in args.tasks:
|
||||
if args.test_overfit:
|
||||
eval_data = dataloaders[task_eval]['label_train']  # the loop variable is task_eval
|
||||
else:
|
||||
eval_data = dataloaders[task_eval]['test']
|
||||
eval_dir = os.path.join(args.output_dir, curr_task)
|
||||
test_one_to_one(curr_task, task_eval, model, score_dict, last=last, eval_data=eval_data, out_dir=eval_dir)
|
||||
logger.info("score: {}".format(score_dict))
|
||||
score_dicts.append(score_dict)
|
||||
|
||||
return score_dicts[0]
|
|
@ -0,0 +1,88 @@
|
|||
import json
|
||||
import os
|
||||
import argparse
|
||||
# from settings import parse_args
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--tasks',nargs='+', default=['ag'])
|
||||
parser.add_argument('--output_dir',type=str)
|
||||
parser.add_argument('--res_dir',type=str, default='scores')
|
||||
parser.add_argument('--exp',type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# TODO: Calculate average scores and average forgetting. =================================================================================
|
||||
test_res_path = os.path.join(args.res_dir, args.exp+'_res.txt')
|
||||
test_res = open(test_res_path, 'a')
|
||||
# * Load saved scores.
|
||||
score_all_time=[]
|
||||
with open(os.path.join(args.output_dir, "metrics.json"),"r") as f:
|
||||
score_dict = [json.loads(row) for row in f.readlines()]
|
||||
|
||||
# print(f'SCORE DICT {score_dict}',flush=True)
|
||||
|
||||
# * Calculate Average Score of all tasks observed.
|
||||
# score_last_time = score_dict[-1] # last time and last epoch
|
||||
# average_score = score_last_time['Average']
|
||||
# print(f'Last time scores for all tasks:\n {score_last_time}', file=test_res)
|
||||
# print('Last time scores for all tasks:\n',score_last_time, flush=True)
|
||||
# print(f"Average score is {average_score['score']}; teacher average score is {average_score['tc_score']}", file=test_res)
|
||||
# print(f"Average score is {average_score['score']}; teacher average score is {average_score['tc_score']}", flush=True)
|
||||
|
||||
# * Calculate the max scores of all tasks
|
||||
tasks_max_scores = {}
|
||||
for task in args.tasks:
|
||||
for row in score_dict:
|
||||
key_name = list(row.keys())[0]
|
||||
if task+'_finish' in key_name:
|
||||
if task not in tasks_max_scores:
|
||||
# print(row[task+'_finish'])
|
||||
tasks_max_scores[task] = max(row[task+'_finish']['score'],row[task+'_finish']['tc_score'])
|
||||
if 'prev_tasks_scores' in key_name:
|
||||
for term in row['prev_tasks_scores']:
|
||||
if task in term:
|
||||
new_task_score = max(term[task]['score'], term[task]['tc_score'])
|
||||
if new_task_score > tasks_max_scores[task]:
|
||||
tasks_max_scores[task] = new_task_score
|
||||
print(f'Tasks max scores are {tasks_max_scores}')
|
||||
print(f'Average score is {sum(tasks_max_scores.values())/len(tasks_max_scores.values())}', file=test_res)
|
||||
print(f'Average score is {sum(tasks_max_scores.values())/len(tasks_max_scores.values())}')
|
||||
|
||||
# * Calculate Backward transfer.
|
||||
each_finish_score = []
|
||||
each_finish_tc_score = []
|
||||
last_task_name = args.tasks[-1]
|
||||
for row in score_dict:
|
||||
key_name = list(row.keys())[0]
|
||||
if '_finish' in key_name and last_task_name not in key_name:
|
||||
# print(f'ROW: {row}',flush=True)
|
||||
for k,v in row.items():
|
||||
# print(v, flush=True)
|
||||
each_finish_score.append(v['score'])
|
||||
each_finish_tc_score.append(v['tc_score'])
|
||||
average_wo_last = score_dict[-2]
|
||||
# bwt = - min(sum(each_finish_score) / len(each_finish_score), sum(each_finish_tc_score) / len(each_finish_tc_score)) + max(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
|
||||
# bwt = - min(sum(each_finish_score) / len(each_finish_score), sum(each_finish_tc_score) / len(each_finish_tc_score)) + min(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
|
||||
# tc_bwt = - sum(each_finish_tc_score) / len(each_finish_tc_score) + max(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
|
||||
prev_max_scores = [tasks_max_scores[task] for task in args.tasks[:-1]]
|
||||
print(f'Each finish score {each_finish_score}', flush=True)
|
||||
bwt = (- min(sum(each_finish_score), sum(each_finish_tc_score)) + sum(prev_max_scores)) / len(args.tasks[:-1])
|
||||
print(f"Backward transfer is {bwt}", file=test_res)
|
||||
print(f"Backward transfer is {bwt}", flush=True)
|
||||
|
||||
# # * Calculate Forward transfer
|
||||
# each_initial_score = []
|
||||
# # each_initial_tc_score = []
|
||||
# for row in score_dict:
|
||||
# if '_initial' in list(row.keys())[0]:
|
||||
# for k,v in row.items():
|
||||
# each_initial_score.append(v['score'])
|
||||
# # each_initial_tc_score.append(v['tc_score'])
|
||||
# fwt = sum(each_initial_score) / len(each_initial_score)
|
||||
# # tc_fwt = sum(each_initial_tc_score) / len(each_initial_tc_score)
|
||||
|
||||
# # * Print out final results of scores.
|
||||
# print(f'Forward transfer is {fwt}', file=test_res)
|
||||
# print(f'Forward transfer is {fwt}', flush=True)
|
||||
|
||||
# print(average_score['tc_score'], tc_bwt,file=test_res)
|
||||
|
|
@ -0,0 +1,434 @@
|
|||
import torch
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
from settings import parse_args
|
||||
from eda import *
|
||||
# from train import special_tokens, special_token_ids
|
||||
|
||||
args = parse_args()
|
||||
class LBIDDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
self.max_ans_len = 10
|
||||
self.special_token_ids = special_token_ids
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def get_answer(self, target): # get answer text from intent
|
||||
# replace - and _ to make the text seem more natural
|
||||
ans = target.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
ans = self.get_answer(example['output'])
|
||||
num_aug = 0 # Number of augmented sentences per sample.
|
||||
if not args.meantc or num_aug == 0:
|
||||
return [(text, ans)]
|
||||
else:
|
||||
aug_text_list = eda(text, num_aug=num_aug)
|
||||
res = [(aug_text, ans) for aug_text in aug_text_list]
|
||||
res.append((text, ans))
|
||||
return res
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
# print(d,flush=True)
|
||||
input_text, ans_text = d
|
||||
|
||||
if args.use_task_pmt:
|
||||
# prefix_id = self.tokz.encode(task_name+':')
|
||||
input_sen = task_name+':' + input_text
|
||||
else:
|
||||
# prefix_id = self.tokz.encode('TASK:')
|
||||
input_sen = 'TASK:' + input_text
|
||||
|
||||
context_sen = input_sen + '<ANS>'
|
||||
all_sen = input_text + '<ANS>' + ans_text
|
||||
ans_sen = ans_text
|
||||
# ans_id = self.tokz.encode(ans_text)
|
||||
# self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
|
||||
# return {
|
||||
# 'all_id': self.tokz.encode(all_sen),
|
||||
# 'input_id': self.tokz.encode(input_sen),
|
||||
# 'context_id': self.tokz.encode(context_sen),
|
||||
# 'ans_id': ans_id
|
||||
# }
|
||||
return {
|
||||
'all_sen': all_sen,
|
||||
'input_sen': input_sen,
|
||||
'context_sen': context_sen,
|
||||
'ans_sen': ans_sen,
|
||||
}
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
# print('data',data[:10],flush=True)
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
|
||||
class ULBIDDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = special_token_ids
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
data = [self.parse_example(i) for i in data] # [utter, label]
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def get_answer(self, target): # get answer text from intent
|
||||
# replace - and _ to make the text seem more natural
|
||||
ans = target.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
ans = self.get_answer(example['output'])
|
||||
return text, ans
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
input_text, ans_text = d
|
||||
if args.use_task_pmt:
|
||||
# prefix_id = self.tokz.encode(task_name+':')
|
||||
input_sen = task_name+':' + input_text
|
||||
else:
|
||||
# prefix_id = self.tokz.encode('TASK:')
|
||||
input_sen = 'TASK:' + input_text
|
||||
|
||||
context_sen = input_sen + '<ANS>'
|
||||
# all_sen = context_sen + ans_text
|
||||
all_sen = input_text + '<ANS>' + ans_text
|
||||
ans_sen = ans_text
|
||||
# ans_id = self.tokz.encode(ans_text)
|
||||
# self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
|
||||
# return {
|
||||
# 'all_id': self.tokz.encode(all_sen),
|
||||
# 'input_id': self.tokz.encode(input_sen),
|
||||
# 'context_id': self.tokz.encode(context_sen),
|
||||
# 'ans_id': ans_id
|
||||
# }
|
||||
return {
|
||||
'all_sen': all_sen,
|
||||
'input_sen': input_sen,
|
||||
'context_sen': context_sen,
|
||||
'ans_sen': ans_sen,
|
||||
}
|
||||
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
|
||||
class PinnedBatch:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, k):
|
||||
return self.data[k]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def pin_memory(self):
|
||||
for k in self.data.keys():
|
||||
self.data[k] = self.data[k].pin_memory()
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return self.data.__repr__()
|
||||
|
||||
def __str__(self):
|
||||
return self.data.__str__()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
|
||||
def pad_seq(seq, pad, max_len, pad_left=False):
|
||||
if pad_left:
|
||||
return [pad] * (max_len - len(seq)) + seq
|
||||
else:
|
||||
return seq + [pad] * (max_len - len(seq))
|
||||
|
||||
|
||||
class OldPadBatchSeq:
|
||||
def __init__(self, pad_id=0):
|
||||
self.pad_id = pad_id
|
||||
|
||||
def __call__(self, batch):
|
||||
input_id = [i['input_id'] for i in batch]
|
||||
context_id = [i['context_id'] for i in batch]
|
||||
all_id = [i['all_id'] for i in batch]
|
||||
ans_id = [i['ans_id'] for i in batch]
|
||||
input_lens = [len(i) for i in input_id]
|
||||
context_lens = [len(i) for i in context_id]
|
||||
all_lens = [len(i) for i in all_id]
|
||||
ans_lens = [len(i) for i in ans_id]
|
||||
|
||||
input_mask = torch.ByteTensor([[1] * input_lens[i] + [0] * (max(input_lens)-input_lens[i]) for i in range(len(input_id))])
|
||||
all_mask = torch.ByteTensor([[1] * all_lens[i] + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))])
|
||||
all_label_mask = torch.ByteTensor([[0] * (context_lens[i]) + [1] * (all_lens[i] - context_lens[i]) + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))]) # Only calculate losses on answer tokens.
|
||||
|
||||
res = {}
|
||||
res["input_id"] = torch.tensor([pad_seq(i, self.pad_id, max(input_lens)) for i in input_id], dtype=torch.long)
|
||||
res["context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(context_lens), pad_left=True) for i in context_id], dtype=torch.long)
|
||||
res["all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(all_lens)) for i in all_id], dtype=torch.long)
|
||||
res["ans_id"] = torch.tensor([pad_seq(i, self.pad_id, max(ans_lens)) for i in ans_id], dtype=torch.long)
|
||||
|
||||
res["input_lens"] = torch.tensor(input_lens, dtype=torch.long)
|
||||
res["context_lens"] = torch.tensor(context_lens, dtype=torch.long)
|
||||
res["all_lens"] = torch.tensor(all_lens, dtype=torch.long)
|
||||
res["ans_lens"] = torch.tensor(ans_lens, dtype=torch.long)
|
||||
|
||||
res['input_mask'] = input_mask
|
||||
res['all_mask'] = all_mask
|
||||
res['all_label_mask'] = all_label_mask
|
||||
|
||||
return PinnedBatch(res)
|
||||
|
||||
|
||||
class PadBatchSeq:
|
||||
def __init__(self, pad_id=0, tokz=None):
|
||||
self.pad_id = pad_id
|
||||
self.tokz = tokz
|
||||
|
||||
def __call__(self, batch):
|
||||
input_sen = [i['input_sen'] for i in batch]
|
||||
context_sen = [i['context_sen'] for i in batch]
|
||||
all_sen = [i['all_sen'] for i in batch]
|
||||
ans_sen = [i['ans_sen'] for i in batch]
|
||||
|
||||
input_encoding = self.tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
all_encoding = self.tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
context_encoding = self.tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
ans_encoding = self.tokz(ans_sen, padding='longest', max_length=args.max_ans_len, truncation=True, return_tensors='pt')
|
||||
|
||||
res = {}
|
||||
res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
|
||||
res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
|
||||
res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
|
||||
res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
|
||||
|
||||
return PinnedBatch(res)
|
||||
|
||||
|
||||
# * Information for different tasks.
|
||||
if args.data_type == 'intent':
|
||||
TASK2INFO = {
|
||||
"banking": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "banking",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"clinc": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "clinc",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"hwu": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "hwu",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"tod": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "FB_TOD_SF",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split1": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split1",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split2": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split2",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split3": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "top_split3",
|
||||
"task_type": "CLS",
|
||||
}
|
||||
}
|
||||
elif args.data_type =='slot':
|
||||
TASK2INFO = {
|
||||
"dstc8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "dstc8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"restaurant8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "restaurant8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_eng": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_eng",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_trivia": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_trivia",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_restaurant": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_restaurant",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "SlotTagging",
|
||||
}
|
||||
}
|
||||
elif args.data_type =='tc':
|
||||
TASK2INFO = {
|
||||
"ag": {
|
||||
"dataset_class": LBIDDataset,
|
||||
"dataset_folder": "ag",
|
||||
"task_type": "tc",
|
||||
},
|
||||
'dbpedia':{
|
||||
'dataset_class':LBIDDataset,
|
||||
'dataset_folder':'dbpedia',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'yelp':{
|
||||
'dataset_class':LBIDDataset,
|
||||
'dataset_folder':'yelp',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'yahoo':{
|
||||
'dataset_class':LBIDDataset,
|
||||
'dataset_folder':'yahoo',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'snips':{
|
||||
'dataset_class':LBIDDataset,
|
||||
'dataset_folder':'snips',
|
||||
'task_type':'tc',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_unlabel_data(path, task):
|
||||
info = TASK2INFO[task]
|
||||
data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
|
||||
return data_path
|
||||
|
||||
def get_unlabel_dict(path, task, tokz, args):
|
||||
info = TASK2INFO[task]
|
||||
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
if args.data_type == 'intent':
|
||||
return LBIDDataset(task, tokz, path, args.ctx_max_len, special_token_ids=special_token_ids)
|
||||
else:
|
||||
return None
|
||||
|
||||
class MixedDataset(LBIDDataset):
|
||||
def __init__(self, unlabel_data, label_data, tokz, ctx_max_len=100):
|
||||
self.ctx_max_len = ctx_max_len
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
|
||||
self.data = []
|
||||
self.data += unlabel_data
|
||||
self.data += label_data
|
||||
|
||||
class AugDataset(LBIDDataset):
|
||||
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, ctx_max_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = 0
|
||||
self.special_token_ids = special_token_ids
|
||||
with open(label_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
with open(unlabel_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data += [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['userInput']['text']
|
||||
aug_text_list = eda(text, num_aug=args.num_aug)
|
||||
ans = self.get_answer(example['intent'])
|
||||
res = [(aug_text, ans) for aug_text in aug_text_list]
|
||||
res.append((text, ans))
|
||||
return res
|
||||
|
||||
def get_datasets(path, tasks, tokz, ctx_max_len=100, special_token_ids=None):
|
||||
res = {}
|
||||
|
||||
for task in tasks:
|
||||
res[task] = {}
|
||||
info = TASK2INFO[task]
|
||||
|
||||
if args.data_type == "intent" or args.data_type == "tc":
|
||||
res[task]['label_train'] = LBIDDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
if args.use_unlabel:
|
||||
res[task]['unlabel_train'] = ULBIDDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], 'unlabel_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
if args.meantc:
|
||||
filename = 'unlabel_' + str(args.num_unlabel) + '_train.json'
|
||||
res[task]['aug_train'] = AugDataset(task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'),
|
||||
os.path.join(path, info['dataset_folder'], filename), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
|
||||
return res
|
||||
|
|
@ -0,0 +1,374 @@
|
|||
import collections
|
||||
import string
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
AGG_OPS = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
|
||||
COND_OPS = ['=', '>', '<', 'OP']
|
||||
COND_OPS = ['=', '>', '<']
|
||||
|
||||
|
||||
def lcsstr(string1, string2):
|
||||
answer = 0
|
||||
len1, len2 = len(string1), len(string2)
|
||||
for i in range(len1):
|
||||
match = 0
|
||||
for j in range(len2):
|
||||
if (i + j < len1 and string1[i + j] == string2[j]):
|
||||
match += 1
|
||||
if match > answer:
|
||||
answer = match
|
||||
else:
|
||||
match = 0
|
||||
return answer
|
||||
|
||||
|
||||
def lcs(X, Y):
|
||||
m = len(X)
|
||||
n = len(Y)
|
||||
L = [[None]*(n + 1) for i in range(m + 1)]
|
||||
|
||||
for i in range(m + 1):
|
||||
for j in range(n + 1):
|
||||
if i == 0 or j == 0 :
|
||||
L[i][j] = 0
|
||||
elif X[i-1] == Y[j-1]:
|
||||
L[i][j] = L[i-1][j-1]+1
|
||||
else:
|
||||
L[i][j] = max(L[i-1][j], L[i][j-1])
|
||||
return L[m][n] / max(m, n) + lcsstr(X, Y) / 1e4 - min(m, n) / 1e8
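
# Note: lcs() returns a composite similarity used only for ranking candidate
# columns: the LCS length normalized by the longer string, plus a small bonus for
# the longest common substring (weight 1e-4), minus a tiny tie-breaking penalty
# proportional to the shorter string's length (weight 1e-8).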
|
||||
|
||||
|
||||
def normalize_text(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
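
# For example, normalize_text("The  Cat, sat!") -> "cat sat": lowercase, strip
# punctuation, drop the articles a/an/the, and collapse whitespace.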
|
||||
|
||||
|
||||
def to_delta_state(line):
|
||||
delta_state = {'inform': {}, 'request': {}}
|
||||
try:
|
||||
if line.lower() == 'none' or line.strip() == '' or line.strip() == ';':
|
||||
return delta_state
|
||||
inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')]
|
||||
inform_pairs = {}
|
||||
for i in inform:
|
||||
try:
|
||||
k, v = i.split(':')
|
||||
inform_pairs[k.strip()] = v.strip()
|
||||
except:
|
||||
pass
|
||||
delta_state = {'inform': inform_pairs, 'request': request}
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
return delta_state
|
||||
|
||||
|
||||
def update_state(state, delta):
|
||||
for act, slot in delta.items():
|
||||
state[act] = slot
|
||||
return state
|
||||
|
||||
|
||||
def dict_cmp(d1, d2):
|
||||
def cmp(a, b):
|
||||
for k1, v1 in a.items():
|
||||
if k1 not in b:
|
||||
return False
|
||||
else:
|
||||
if v1 != b[k1]:
|
||||
return False
|
||||
return True
|
||||
return cmp(d1, d2) and cmp(d2, d1)
|
||||
|
||||
|
||||
def to_lf(s, table):
|
||||
aggs = [y.lower() for y in AGG_OPS]
|
||||
agg_to_idx = {x: i for i, x in enumerate(aggs)}
|
||||
conditionals = [y.lower() for y in COND_OPS]
|
||||
headers_unsorted = [(y.lower(), i) for i, y in enumerate(table['header'])]
|
||||
headers = [(y.lower(), i) for i, y in enumerate(table['header'])]
|
||||
headers.sort(reverse=True, key=lambda x: len(x[0]))
|
||||
condition_s, conds = None, []
|
||||
if 'where' in s:
|
||||
s, condition_s = s.split('where', 1)
|
||||
|
||||
s = ' '.join(s.split()[1:-2])
|
||||
s_no_agg = ' '.join(s.split()[1:])
|
||||
sel, agg = None, 0
|
||||
lcss, idxs = [], []
|
||||
for col, idx in headers:
|
||||
lcss.append(lcs(col, s))
|
||||
lcss.append(lcs(col, s_no_agg))
|
||||
idxs.append(idx)
|
||||
lcss = np.array(lcss)
|
||||
max_id = np.argmax(lcss)
|
||||
sel = idxs[max_id // 2]
|
||||
if max_id % 2 == 1: # with agg
|
||||
agg = agg_to_idx[s.split()[0]]
|
||||
|
||||
full_conditions = []
|
||||
if not condition_s is None:
|
||||
|
||||
pattern = '|'.join(COND_OPS)
|
||||
split_conds_raw = re.split(pattern, condition_s)
|
||||
split_conds_raw = [conds.strip() for conds in split_conds_raw]
|
||||
split_conds = [split_conds_raw[0]]
|
||||
for i in range(1, len(split_conds_raw)-1):
|
||||
split_conds.extend(re.split('and', split_conds_raw[i]))
|
||||
split_conds += [split_conds_raw[-1]]
|
||||
for i in range(0, len(split_conds), 2):
|
||||
cur_s = split_conds[i]
|
||||
lcss = []
|
||||
for col in headers:
|
||||
lcss.append(lcs(col[0], cur_s))
|
||||
max_id = np.argmax(np.array(lcss))
|
||||
split_conds[i] = headers[max_id][0]
|
||||
for i, m in enumerate(re.finditer(pattern, condition_s)):
|
||||
split_conds[2*i] = split_conds[2*i] + ' ' + m.group()
|
||||
split_conds = [' '.join(split_conds[2*i:2*i+2]) for i in range(len(split_conds)//2)]
|
||||
|
||||
condition_s = ' and '.join(split_conds)
|
||||
condition_s = ' ' + condition_s + ' '
|
||||
for idx, col in enumerate(headers):
|
||||
condition_s = condition_s.replace(' ' + col[0] + ' ', ' Col{} '.format(col[1]))
|
||||
condition_s = condition_s.strip()
|
||||
|
||||
for idx, col in enumerate(conditionals):
|
||||
new_s = []
|
||||
for t in condition_s.split():
|
||||
if t == col:
|
||||
new_s.append('Cond{}'.format(idx))
|
||||
else:
|
||||
new_s.append(t)
|
||||
condition_s = ' '.join(new_s)
|
||||
s = condition_s
|
||||
|
||||
conds = re.split(r'(Col\d+ Cond\d+)', s)
|
||||
if len(conds) == 0:
|
||||
conds = [s]
|
||||
conds = [x for x in conds if len(x.strip()) > 0]
|
||||
full_conditions = []
|
||||
for i, x in enumerate(conds):
|
||||
if i % 2 == 0:
|
||||
x = x.split()
|
||||
col_num = int(x[0].replace('Col', ''))
|
||||
opp_num = int(x[1].replace('Cond', ''))
|
||||
full_conditions.append([col_num, opp_num])
|
||||
else:
|
||||
x = x.split()
|
||||
if x[-1] == 'and':
|
||||
x = x[:-1]
|
||||
x = ' '.join(x)
|
||||
if 'Col' in x:
|
||||
new_x = []
|
||||
for t in x.split():
|
||||
if 'Col' in t:
|
||||
idx = int(t.replace('Col', ''))
|
||||
t = headers_unsorted[idx][0]
|
||||
new_x.append(t)
|
||||
x = new_x
|
||||
x = ' '.join(x)
|
||||
if 'Cond' in x:
|
||||
new_x = []
|
||||
for t in x.split():
|
||||
if 'Cond' in t:
|
||||
idx = int(t.replace('Cond', ''))
|
||||
t = conditionals[idx]
|
||||
new_x.append(t)
|
||||
x = new_x
|
||||
x = ' '.join(x)
|
||||
full_conditions[-1].append(x.replace(' ', ''))
|
||||
logical_form = {'sel': sel, 'conds': full_conditions, 'agg': agg}
|
||||
return logical_form
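# The returned structure mirrors the WikiSQL-style ground truth compared against in
# computeLFEM below:
#   {'sel':   selected column index,
#    'agg':   index into AGG_OPS,
#    'conds': [[column index, index into COND_OPS, condition value (spaces removed)], ...]}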
|
||||
|
||||
|
||||
def computeLFEM(greedy, answer):
|
||||
count = 0
|
||||
correct = 0
|
||||
text_answers = []
|
||||
for idx, (g, ex) in enumerate(zip(greedy, answer)):
|
||||
count += 1
|
||||
text_answers.append([ex['answer'].lower()])
|
||||
try:
|
||||
gt = ex['sql']
|
||||
conds = gt['conds']
|
||||
lower_conds = []
|
||||
for c in conds:
|
||||
lc = c
|
||||
lc[2] = str(lc[2]).lower().replace(' ', '')
|
||||
lower_conds.append(lc)
|
||||
gt['conds'] = lower_conds
|
||||
lf = to_lf(g, ex['table'])
|
||||
correct += lf == gt
|
||||
except Exception as e:
|
||||
continue
|
||||
return (correct / count) * 100, text_answers
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = prediction.split()
|
||||
ground_truth_tokens = ground_truth.split()
|
||||
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
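# Example: f1_score("the cat sat", "the cat") -> precision 2/3, recall 1.0, F1 0.8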
|
||||
|
||||
|
||||
def exact_match(prediction, ground_truth):
|
||||
# print(f'pred {prediction}, gt {ground_truth}')
|
||||
return prediction == ground_truth
|
||||
|
||||
|
||||
def computeF1(outputs, targets):
|
||||
return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)]) / len(outputs) * 100
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for idx, ground_truth in enumerate(ground_truths):
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
|
||||
def computeEM(outputs, targets):
|
||||
outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)]
|
||||
# print(f'EM outs: {outs}',flush=True)
|
||||
return sum(outs) / len(outputs) * 100
|
||||
|
||||
|
||||
def score(answer, gold):
|
||||
if len(gold) > 0:
|
||||
gold = set.union(*[simplify(g) for g in gold])
|
||||
answer = simplify(answer)
|
||||
tp, tn, sys_pos, real_pos = 0, 0, 0, 0
|
||||
if answer == gold:
|
||||
if not ('unanswerable' in gold and len(gold) == 1):
|
||||
tp += 1
|
||||
else:
|
||||
tn += 1
|
||||
if not ('unanswerable' in answer and len(answer) == 1):
|
||||
sys_pos += 1
|
||||
if not ('unanswerable' in gold and len(gold) == 1):
|
||||
real_pos += 1
|
||||
return np.array([tp, tn, sys_pos, real_pos])
|
||||
|
||||
|
||||
def simplify(answer):
|
||||
return set(''.join(c for c in t if c not in string.punctuation) for t in answer.strip().lower().split()) - {'the', 'a', 'an', 'and', ''}
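# Example: simplify("The cat, and a dog!") -> {'cat', 'dog'}
# (lowercased, punctuation stripped; articles, 'and', and empty tokens dropped)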
|
||||
|
||||
|
||||
def computeCF1(greedy, answer):
|
||||
scores = np.zeros(4)
|
||||
for g, a in zip(greedy, answer):
|
||||
scores += score(g, a)
|
||||
tp, tn, sys_pos, real_pos = scores.tolist()
|
||||
total = len(answer)
|
||||
if tp == 0:
|
||||
p = r = f = 0.0
|
||||
else:
|
||||
p = tp / float(sys_pos)
|
||||
r = tp / float(real_pos)
|
||||
f = 2 * p * r / (p + r)
|
||||
return f * 100, p * 100, r * 100
|
||||
|
||||
|
||||
def computeDialogue(greedy, answer):
|
||||
examples = []
|
||||
for idx, (g, a) in enumerate(zip(greedy, answer)):
|
||||
examples.append((a[0], g, a[1], idx))
|
||||
#examples.sort()
|
||||
turn_request_positives = 0
|
||||
turn_goal_positives = 0
|
||||
joint_goal_positives = 0
|
||||
ldt = None
|
||||
for ex in examples:
|
||||
if ldt is None or ldt.split('_')[:-1] != ex[0].split('_')[:-1]:
|
||||
state, answer_state = {}, {}
|
||||
ldt = ex[0]
|
||||
delta_state = to_delta_state(ex[1])
|
||||
answer_delta_state = to_delta_state(ex[2])
|
||||
state = update_state(state, delta_state['inform'])
|
||||
answer_state = update_state(answer_state, answer_delta_state['inform'])
|
||||
if dict_cmp(state, answer_state):
|
||||
joint_goal_positives += 1
|
||||
if delta_state['request'] == answer_delta_state['request']:
|
||||
turn_request_positives += 1
|
||||
if dict_cmp(delta_state['inform'], answer_delta_state['inform']):
|
||||
turn_goal_positives += 1
|
||||
|
||||
joint_goal_em = joint_goal_positives / len(examples) * 100
|
||||
turn_request_em = turn_request_positives / len(examples) * 100
|
||||
turn_goal_em = turn_goal_positives / len(examples) * 100
|
||||
answer = [(x[-1], x[-2]) for x in examples]
|
||||
#answer.sort()
|
||||
answer = [[x[1]] for x in answer]
|
||||
return joint_goal_em, turn_request_em, turn_goal_em, answer
|
||||
|
||||
|
||||
def compute_metrics(data, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False):
|
||||
greedy = [datum[0] for datum in data] # prediction
|
||||
answer = [datum[1] for datum in data] # ground truth
|
||||
for i in range(3):
|
||||
pred, tgt = data[i]
|
||||
print('pred', pred,'tgt',tgt, flush=True)
|
||||
metric_keys = []
|
||||
metric_values = []
|
||||
if logical_form:
|
||||
lfem, answer = computeLFEM(greedy, answer)
|
||||
metric_keys += ['lfem']
|
||||
metric_values += [lfem]
|
||||
em = computeEM(greedy, answer)
|
||||
metric_keys.append('em')
|
||||
metric_values.append(em)
|
||||
norm_greedy = [normalize_text(g) for g in greedy]
|
||||
norm_answer = [[normalize_text(a) for a in ans] for ans in answer]
|
||||
nf1 = computeF1(norm_greedy, norm_answer)
|
||||
nem = computeEM(norm_greedy, norm_answer)
|
||||
metric_keys.extend(['nf1', 'nem'])
|
||||
metric_values.extend([nf1, nem])
|
||||
if corpus_f1:
|
||||
corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
|
||||
metric_keys += ['corpus_f1', 'precision', 'recall']
|
||||
metric_values += [corpus_f1, precision, recall]
|
||||
if dialogue:
|
||||
joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
|
||||
avg_dialogue = (joint_goal_em + request_em) / 2
|
||||
metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
|
||||
metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
|
||||
metric_dict = collections.OrderedDict(list(zip(metric_keys, metric_values)))
|
||||
return metric_dict
|
||||
|
||||
task2metric_dict = {'woz.en':'avg_dialogue',
|
||||
'wikisql':'lfem',
|
||||
'squad':'nf1',
|
||||
'sst':'em',
|
||||
'srl':'nf1',
|
||||
'ag':'em',
|
||||
'yahoo':'em',
|
||||
'yelp':'em',
|
||||
'amazon':'em',
|
||||
'dbpedia':'em'
|
||||
}
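# Minimal usage sketch (hypothetical variable names), assuming `data` holds at least
# three (prediction, [gold answers]) pairs as compute_metrics above expects:
#
#   data = [(pred, [gold]) for pred, gold in zip(predictions, references)]
#   metrics = compute_metrics(data)                  # e.g. {'em': ..., 'nf1': ..., 'nem': ...}
#   headline = metrics[task2metric_dict['sst']]      # 'em' is the headline metric for sst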
|
|
@ -0,0 +1,18 @@
|
|||
import json
|
||||
import os, sys
|
||||
|
||||
oridir="DATA_Divided/DATA_100/ag"
|
||||
outpath="DATA_Divided/DATA_100/ag/pretrain.txt"
|
||||
|
||||
datatype_list=['label_train','unlabel_train']
|
||||
with open(outpath,'w') as fw:
|
||||
for datatype in datatype_list:
|
||||
datapath = os.path.join(oridir,datatype+'.json')
|
||||
with open(datapath,'r') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
for row in data:
|
||||
# print(json.dumps(row['input'].strip('"'), ensure_ascii=False),file=fw)
|
||||
print(row['input'].strip('"'), file=fw)
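# Note: each JSON line is assumed to carry at least an 'input' field; the LAMOL
# conversion script in this repo also writes 'output', 'question', and
# 'text_wo_question' fields, e.g.
#   {"input": "...", "output": "...", "question": "...", "text_wo_question": "..."}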
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
|
@ -0,0 +1,25 @@
|
|||
import random
|
||||
|
||||
# seed_list = [2, 102, 22, 32, 42]
|
||||
seed_list = [2, 32, 42]
|
||||
# intent_list = ['atis', 'banking', 'clinc', 'hwu', 'snips', 'top_split1', 'top_split2' ,'top_split3']
|
||||
# res_order=open('six_order_intent.txt','w')
|
||||
# prelist = ['atis', 'dstc8', 'mit_movie_eng', 'snips', 'mit_movie_trivia', 'mit_restaurant']
|
||||
# prelist = ['yahoo','yelp','ag','dbpedia','amazon']
|
||||
# prelist = ['srl','sst','wikisql','woz.en','squad']
|
||||
prelist = ['srl','sst','wikisql','woz.en','squad','yahoo','yelp','ag','dbpedia','amazon']
|
||||
# res_order=open('text_classification_order.txt','w')
|
||||
res_order=open('mix_order.txt','w')
|
||||
|
||||
res_list = []
|
||||
for i in range(len(seed_list)):
|
||||
random.seed(seed_list[i])
|
||||
random.shuffle(prelist)
|
||||
# print(intent_list)
|
||||
# res_list.append(intent_list)
|
||||
# print(intent_list,file=res_order)
|
||||
print('('+' '.join(prelist)+')',file=res_order)
|
||||
# print('\n',file=res_order)
|
||||
|
||||
res_order.close()
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
adapter_transformers==3.0.0
datasets==2.7.1
flax==0.6.2
GPUtil==1.4.0
huggingface_hub==0.11.0
jax==0.3.25
nltk==3.6.7
numpy==1.19.5
optax==0.1.4
rouge==1.0.1
scikit_learn==1.1.3
torch==1.8.2
tqdm==4.64.0
transformers==4.2.2
|
|
@ -0,0 +1,262 @@
|
|||
gpuid=$1
|
||||
# K=$2
|
||||
K=100
|
||||
R=20 #* Ratio of unlabeled data over labeled data.
|
||||
# K=5000
|
||||
run=K${K}_U${R}_run
|
||||
ord=0
|
||||
# data_type=tc #!
|
||||
# data_type=decanlp
|
||||
data_type=mix
|
||||
exp=semi_${data_type}_$run
|
||||
task_type=semi #!
|
||||
use_unlabel=true
|
||||
meantc=true #* Use mean teacher for semi-supervised learning.
|
||||
echo exp: $exp
|
||||
echo data_type: $data_type
|
||||
echo K: $K U: $R
|
||||
echo order: $ord
|
||||
|
||||
# * tc ===========================================
|
||||
if [[ $data_type == tc ]]
|
||||
then
|
||||
if [[ $ord == 0 ]]
|
||||
then
|
||||
task_list=(ag yelp dbpedia amazon yahoo) #ord0
|
||||
elif [[ $ord == 1 ]]
|
||||
then
|
||||
task_list=(yahoo amazon ag dbpedia yelp) #ord1
|
||||
elif [[ $ord == 2 ]]
|
||||
then
|
||||
task_list=(ag dbpedia yahoo yelp amazon) #ord2
|
||||
elif [[ $ord == 3 ]]
|
||||
then
|
||||
task_list=(yahoo yelp amazon dbpedia ag) #ord3
|
||||
elif [[ $ord == 4 ]]
|
||||
then
|
||||
task_list=(dbpedia yelp amazon ag yahoo) #ord4
|
||||
fi
|
||||
max_input_len=150
|
||||
max_ans_len=20
|
||||
epochs=120
|
||||
# epochs=75
|
||||
train_bs=16
|
||||
gradient_accumulation_steps=1
|
||||
eval_bs=16
|
||||
feedback_threshold=0.05
|
||||
simi_tau=0.8
|
||||
# simi_tau=0.65
|
||||
# add_confidence_selection=
|
||||
add_confidence_selection=true
|
||||
fi
|
||||
|
||||
# * decanlp ======================================
|
||||
if [[ $data_type == decanlp ]]
|
||||
then
|
||||
if [[ $ord == 0 ]]
|
||||
then
|
||||
task_list=(wikisql sst woz.en squad srl) #ord0
|
||||
# task_list=(woz.en sst)
|
||||
elif [[ $ord == 1 ]]
|
||||
then
|
||||
task_list=(srl squad wikisql woz.en sst) #ord1
|
||||
elif [[ $ord == 2 ]]
|
||||
then
|
||||
task_list=(wikisql woz.en srl sst squad) #ord2
|
||||
elif [[ $ord == 3 ]]
|
||||
then
|
||||
task_list=(srl sst squad woz.en wikisql) #ord3
|
||||
elif [[ $ord == 4 ]]
|
||||
then
|
||||
task_list=(woz.en sst squad wikisql srl) #ord4
|
||||
fi
|
||||
max_input_len=512
|
||||
max_ans_len=100
|
||||
epochs=200
|
||||
# epochs=100 #!test
|
||||
train_bs=4
|
||||
gradient_accumulation_steps=4
|
||||
eval_bs=16
|
||||
feedback_threshold=0.1
|
||||
# feedback_threshold=0.000001 #!test
|
||||
simi_tau=0.6
|
||||
add_confidence_selection=true
|
||||
fi
|
||||
# * MIX ================================================
|
||||
if [[ $data_type == mix ]]
|
||||
then
|
||||
if [[ $ord == 0 ]]
|
||||
then
|
||||
task_list=(yahoo amazon woz.en squad yelp ag wikisql dbpedia sst srl) #ord0
|
||||
elif [[ $ord == 1 ]]
|
||||
then
|
||||
task_list=(yelp wikisql yahoo sst srl ag dbpedia woz.en squad amazon) #ord1
|
||||
elif [[ $ord == 2 ]]
|
||||
then
|
||||
task_list=(woz.en sst yahoo squad ag dbpedia amazon srl yelp wikisql) #ord2
|
||||
fi
|
||||
max_input_len=512
|
||||
max_ans_len=100
|
||||
epochs=200
|
||||
train_bs=4
|
||||
gradient_accumulation_steps=4
|
||||
eval_bs=16
|
||||
feedback_threshold=0.1
|
||||
simi_tau=0.6
|
||||
add_confidence_selection=true
|
||||
fi
|
||||
|
||||
# *----------------------------------------------
|
||||
echo task list: ${task_list[@]}
|
||||
|
||||
# random_initialization=true
|
||||
random_initialization=
|
||||
num_label=$K
|
||||
unlabel_ratio=$R
|
||||
use_task_pmt=true
|
||||
test_all=true
|
||||
# test_all= #!test
|
||||
# debug_use_unlabel=true
|
||||
debug_use_unlabel=
|
||||
# evaluate_zero_shot=true
|
||||
evaluate_zero_shot=
|
||||
|
||||
lr=2e-4
|
||||
# warmup_epoch=1
|
||||
warmup_epoch=2 # * warm up epoch
|
||||
freeze_plm=true
|
||||
# freeze_plm=
|
||||
evalstep=500000
|
||||
# test_overfit=true #! Use label_train for evaluation
|
||||
test_overfit=
|
||||
|
||||
# * Unlabel hyper-params.
|
||||
gen_replay=true #* Perform generative replay
|
||||
# gen_replay=
|
||||
pseudo_data_ratio=0.1
|
||||
# pseudo_data_ratio=0.01 #!test
|
||||
construct_memory=true #* Construct memory (Forward Aug)
|
||||
# construct_memory=
|
||||
backward_augment=true #* Backward Aug
|
||||
# backward_augment=
|
||||
back_kneighbors=3 # for backward_augment to retrieve the current unlabel memory
|
||||
# forward_augment=true #* Forward Aug
|
||||
forward_augment=
|
||||
kneighbors=3 #
|
||||
# kneighbors=5 # for forward_augment to retrieve the old memory
|
||||
select_adapter=
|
||||
# select_adapter=true
|
||||
|
||||
unlabel_amount='1'
|
||||
# unlabel_amount=$unlabel_ratio
|
||||
# unlabel_amount='5'
|
||||
# ungamma=0.1
|
||||
ungamma=0.01
|
||||
add_unlabel_lm_loss=true
|
||||
# add_unlabel_lm_loss=
|
||||
add_label_lm_loss=true
|
||||
# add_label_lm_loss=
|
||||
accum_grad_iter=$unlabel_amount
|
||||
# feedback_threshold=0.5
|
||||
# lm_lambda=0.25
|
||||
lm_lambda=0.5
|
||||
KD_temperature=2
|
||||
KD_term=1
|
||||
diff_question=true
|
||||
pseudo_tau=1.5
|
||||
num_aug=3 # for EDA input augmentation
|
||||
consistency=10
|
||||
consistency_rampup=30
|
||||
# consistency_rampup=10
|
||||
# ema_decay=0.90
|
||||
ema_decay=0.95
|
||||
stu_feedback=true
|
||||
input_aug=true
|
||||
rdrop=true
|
||||
model_aug=true
|
||||
# model_aug=
|
||||
# rdrop=
|
||||
|
||||
# *--------------------------------------------------
|
||||
# datadir=./data_lamol/TC
|
||||
if [[ $data_type == tc ]]
|
||||
then
|
||||
datadir=../../DATA/MYDATA_DIVIDED/TC/label_$K
|
||||
elif [[ $data_type == decanlp ]]
|
||||
then
|
||||
datadir=../../DATA/MYDATA_DIVIDED/decaNLP/label_$K
|
||||
elif [[ $data_type == mix ]]
|
||||
then
|
||||
datadir=../../DATA/MYDATA_DIVIDED/MIX/label_$K
|
||||
fi
|
||||
|
||||
output=outputs/${data_type}/$exp/ord$ord
|
||||
tb_log_dir=tb_logs/${data_type}/$exp/ord$ord
|
||||
log=logs/${data_type}/${exp}_ord${ord}.log
|
||||
err=logs/${data_type}/${exp}_ord${ord}.err
|
||||
mkdir -p logs/${data_type} outputs $output $tb_log_dir
|
||||
# *--------------------------------------------------
|
||||
# python_file=tc_train.py
|
||||
python_file=unitrain.py
|
||||
|
||||
# TODO: Running ================================================
|
||||
CUDA_VISIBLE_DEVICES=$gpuid \
|
||||
python $python_file \
|
||||
--gpu=$gpuid \
|
||||
--data_dir=$datadir \
|
||||
--output_dir=$output \
|
||||
--tb_log_dir=$tb_log_dir \
|
||||
--tasks ${task_list[*]} \
|
||||
--experiment=$exp \
|
||||
--num_train_epochs=$epochs \
|
||||
--train_batch_size=$train_bs \
|
||||
--eval_batch_size=$eval_bs \
|
||||
--eval_steps=$evalstep \
|
||||
--data_type=$data_type \
|
||||
--use_unlabel=$use_unlabel \
|
||||
--pseudo_tau=$pseudo_tau \
|
||||
--warmup_epoch=$warmup_epoch \
|
||||
--ungamma=$ungamma \
|
||||
--unlabel_amount=$unlabel_amount \
|
||||
--meantc=$meantc \
|
||||
--consistency=$consistency \
|
||||
--num_aug=$num_aug \
|
||||
--consistency_rampup=$consistency_rampup \
|
||||
--test_overfit=$test_overfit \
|
||||
--ema_decay=$ema_decay \
|
||||
--num_label=$num_label \
|
||||
--use_task_pmt=$use_task_pmt \
|
||||
--freeze_plm=$freeze_plm \
|
||||
--input_aug=$input_aug \
|
||||
--model_aug=$model_aug \
|
||||
--stu_feedback=$stu_feedback \
|
||||
--rdrop=$rdrop \
|
||||
--test_all=$test_all \
|
||||
--unlabel_ratio=$unlabel_ratio \
|
||||
--gen_replay=$gen_replay \
|
||||
--add_unlabel_lm_loss=$add_unlabel_lm_loss \
|
||||
--add_label_lm_loss=$add_label_lm_loss \
|
||||
--lr=$lr \
|
||||
--accum_grad_iter=$accum_grad_iter \
|
||||
--feedback_threshold=$feedback_threshold \
|
||||
--lm_lambda=$lm_lambda \
|
||||
--kneighbors=$kneighbors \
|
||||
--construct_memory=$construct_memory \
|
||||
--KD_term=$KD_term \
|
||||
--diff_question=$diff_question \
|
||||
--KD_temperature=$KD_temperature \
|
||||
--add_confidence_selection=$add_confidence_selection \
|
||||
--debug_use_unlabel=$debug_use_unlabel \
|
||||
--gradient_accumulation_steps=$gradient_accumulation_steps \
|
||||
--max_input_len=$max_input_len \
|
||||
--max_ans_len=$max_ans_len \
|
||||
--backward_augment=$backward_augment \
|
||||
--back_kneighbors=$back_kneighbors \
|
||||
--similarity_tau=$simi_tau \
|
||||
--select_adapter=$select_adapter \
|
||||
--pseudo_data_ratio=$pseudo_data_ratio \
|
||||
--forward_augment=$forward_augment \
|
||||
--evaluate_zero_shot=$evaluate_zero_shot \
|
||||
--random_initialization=$random_initialization \
|
||||
> $log 2> $err
|
||||
#! not log test
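# Example invocation: the GPU id is the only positional argument (gpuid=$1 above),
# so e.g. `sh scripts/lltrain.sh 0` trains on GPU 0 with the settings defined above.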
|
|
@ -0,0 +1,163 @@
|
|||
import os
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# * New arguments for semi-supervised continual learning.
|
||||
parser.add_argument('--newmm_size', default=0.2, type=float, help='Different memory size for storing new task unlabeled data.')
|
||||
parser.add_argument('--random_initialization',default=False, type=bool)
|
||||
parser.add_argument('--evaluate_zero_shot',default=False, type=bool)
|
||||
parser.add_argument('--forward_augment',default=False, type=bool, help='Whether apply forward augmentation from old constructed memory.')
|
||||
parser.add_argument('--backward_augment', type=bool, default=False, help='Whether use the current task unlabeled data to augment the old tasks.')
|
||||
parser.add_argument('--select_adapter',default=False, type=bool, help='Whether select the previous adapter for the new task initialization.')
|
||||
parser.add_argument('--similarity_tau',default=0.6, type=float, help='Similarity threshold of cosine for KNN.')
|
||||
parser.add_argument('--back_kneighbors',default=5, type=int, help='Retrieve K nearest neighbors from the current unlabeled memory.')
|
||||
parser.add_argument('--diff_question', type=bool, default=False, help='If true, the dataset has different questions for each sample.')
|
||||
parser.add_argument('--debug_use_unlabel', default=False, type=bool, help='Always use unlabeled data for debugging.')
|
||||
parser.add_argument('--add_confidence_selection', type=bool, default=False, help='Whether add confidence selection for pseudo label.')
|
||||
parser.add_argument('--kneighbors', default=5, type=int, help='Retrieve K nearest neighbors from the memory.')
|
||||
parser.add_argument('--construct_memory',default=False, help='Construct memory for generated pseudo data from old tasks.')
|
||||
parser.add_argument('--aug_neighbors_from_memory', default=False, help='Whether augment retrieved neighbors inputs from memory.')
|
||||
parser.add_argument('--lm_lambda',type=float, default=0.5, help='The coefficient of language modeling loss.')
|
||||
parser.add_argument('--accum_grad_iter', type=bool, default=1, help='Accumulation steps for gradients.')
|
||||
parser.add_argument('--add_label_lm_loss', type=bool, default=False, help='Whether compute lm loss on the labeled inputs.')
|
||||
parser.add_argument('--add_unlabel_lm_loss', type=bool, default=False, help='Whether compute lm loss on the unlabeled inputs.')
|
||||
parser.add_argument('--pretrain_epoch',type=int, default=20, help='Pretrain first for e.g.20 epochs before finetuning.')
|
||||
parser.add_argument('--pretrain_first', type=bool, default=False, help='Whether to perform pretraining before each task finetuning.')
|
||||
parser.add_argument('--test_all', type=bool, default=False, help='Whether evaluate all test data after training the last epoch.')
|
||||
parser.add_argument('--add_pretrain_with_ft', type=bool, default=False, help='Whether add pretraining MLM loss during finetuning.')
|
||||
parser.add_argument('--rdrop', default=False, type=bool, help='Whether use R-drop as the consistency regularization, where dropout is used as model augmentation.')
|
||||
parser.add_argument('--stu_feedback', default=False, type=bool, help='Whether to consider feedback from student to teacher on labeled datasets.')
|
||||
parser.add_argument('--feedback_threshold', default=1e-1, type=float, help='Threshold of student feedback to determine whether to use teacher to teach the student.')
|
||||
parser.add_argument('--dropout_rate',default=0.3, type=float, help='Modify dropout rate for model augmentation with dropout.')
|
||||
parser.add_argument('--freeze_plm', default=False, type=bool, help='Whether to freeze the pretrained LM and only finetuning the added adapters.')
|
||||
parser.add_argument('--max_input_len',default=512, type=int, help='Max number of input tokens.')
|
||||
parser.add_argument('--max_ans_len', default=100, type=int, help='Max number of answer tokens.')
|
||||
parser.add_argument('--use_task_pmt',default=False, type=bool, help='Whether use task specific token as the prefix.')
|
||||
parser.add_argument('--use_ema', default=False, type=bool, help='Use EMA for fixmatch.')
|
||||
parser.add_argument('--fixmatch', default=False, type=bool, help='Use fixmatch to solve semi-supervised learning.')
|
||||
parser.add_argument('--num_label', default=500, type=int, help='Number of labeled data we use.')
|
||||
parser.add_argument('--unlabel_ratio', default=500, type=int, help='Ratio of unlabeled data over labeled data we use.')
|
||||
parser.add_argument('--logit_distance_cost', default=-1, type=float, help='let the student model have two outputs and use an MSE loss between the logits with the given weight (default: only have one output)')
|
||||
parser.add_argument('--test_overfit', default=False, type=bool, help='Test whether the model can overfit on labeled train dataset.')
|
||||
parser.add_argument('--num_aug', type=int, default=3, help='Number of augmented sentences per sample with EDA.')
|
||||
parser.add_argument('--ema_decay', default=0.999, type=float, metavar='ALPHA', help='ema variable decay rate (default: 0.999)')
|
||||
parser.add_argument('--consistency_rampup', type=int, default=30, metavar='EPOCHS', help='length of the consistency loss ramp-up')
|
||||
parser.add_argument('--consistency', type=float, default=100.0, metavar='WEIGHT', help='use consistency loss with given weight (default: None)')
|
||||
parser.add_argument('--meantc', type=bool, default=False, help='Whether to use mean teacher to deal with SSL.')
|
||||
parser.add_argument('--only_update_prompts', type=bool, default=False, help='Only update the prompt tokens params.')
|
||||
parser.add_argument('--pl_sample_num', type=int, default=1, help='Number of sampling for pseudo labeling.')
|
||||
parser.add_argument('--pl_sampling', type=bool, default=False, help='Perform pseudo labeling with sampling for unlabeled data. That means, sample several times for each example, and choose the one with the lowest ppl.')
|
||||
parser.add_argument('--unlabel_amount', type=str, default='1', help='How much unlabeled data used per labeled batch.')
|
||||
parser.add_argument('--ungamma', type=float, default=0.2, help='Weight added to the unlabeled loss.')
|
||||
parser.add_argument('--KD_term', type=float, default=1, help="Control how many teacher model signal is passed to student.")
|
||||
parser.add_argument('--KD_temperature', type=float, default=2.0, help="Temperature used to calculate KD loss.")
|
||||
parser.add_argument('--consist_regularize', type=bool, default=False, help='Whether to use consistency regularization for SSL.')
|
||||
parser.add_argument('--stu_epochs', type=int, default=1, help='Number of epochs for training student model.')
|
||||
parser.add_argument('--noisy_stu', type=bool, default=False, help='Whether to use noisy student self-training framework for SSL, generate pseudo labels for all unlabeled data once.')
|
||||
parser.add_argument('--online_noisy_stu', type=bool, default=False, help='Noisy student self-training framework with PL and CR. Perform pseudo labeling online.')
|
||||
parser.add_argument('--naive_pseudo_labeling', type=bool, default=False, help='Whether to use pseudo-labeling for semi-supervised learning.')
|
||||
parser.add_argument('--only_pseudo_selection', type=bool, default=False, help='Whether to select pseudo labels with high confidence.')
|
||||
parser.add_argument('--pretrain_ul_input', type=bool, default=False, help='Only pretrain the inputs of unlabeled data.')
|
||||
parser.add_argument('--use_unlabel', type=bool, default=False, help='Whether to use unlabeled data during training.')
|
||||
parser.add_argument('--use_infix', type=bool, default=False, help='Whether to use infix prompts.')
|
||||
parser.add_argument('--use_prefix', type=bool, default=False, help='Whether to use prefix prompts.')
|
||||
parser.add_argument('--preseqlen', type=int, default=5, help='Number of prepended prefix tokens.')
|
||||
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
|
||||
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
|
||||
parser.add_argument('--pseudo_tau', type=float, default=0.6, help='Threshold to choose pseudo labels with high confidence.')
|
||||
parser.add_argument('--input_aug', type=bool, default=False, help='Whether to use EDA on inputs.')
|
||||
parser.add_argument('--model_aug', type=bool, default=False, help='Whether to use dropout on the model as the data augmentation.')
|
||||
parser.add_argument('--warmup_epoch', type=int, default=10, help='Not use unlabeled data before training certain epochs.')
|
||||
|
||||
# * Old arguments.
|
||||
parser.add_argument('--data_type',type=str,default='intent')
|
||||
parser.add_argument('--use_memory', default=False, type=bool, help="Whether store the learned latent variables z and other info into the memory.")
|
||||
parser.add_argument('--experiment', type=str)
|
||||
parser.add_argument('--tasks',nargs='+', default=['banking'])
|
||||
parser.add_argument("--data_dir", default="./PLL_DATA/", type=str, help="The path to train/dev/test data files.") # Default parameters are set based on single GPU training
|
||||
parser.add_argument("--output_dir", default="./output/dstc", type=str,help="The output directory where the model checkpoints and predictions will be written.")
|
||||
parser.add_argument("--tb_log_dir", default="./tb_logs/dstc", type=str,help="The tensorboard output directory.")
|
||||
parser.add_argument("--res_dir",default="./res", type=str,help="The path to save scores of experiments.")
|
||||
parser.add_argument("--gene_batch_size", default=16, type=int) # For generation.
|
||||
parser.add_argument('--top_k', type=int, default=100)
|
||||
parser.add_argument('--top_p', type=float, default=0.95)
|
||||
parser.add_argument('--do_train',type=bool, default=True)
|
||||
parser.add_argument('--gen_replay',type=bool, default=False, help='Whether use generative replay to avoid forgetting.')
|
||||
parser.add_argument('--model_path', type=str, help='pretrained model path to local checkpoint')
|
||||
parser.add_argument('--generate_dur_train', type=bool, help='Generate reconstructed input utterances during training with CVAE.')
|
||||
parser.add_argument('--temperature',type=float, default=0.95)
|
||||
parser.add_argument('--pseudo_data_ratio', type=float, default=0.05, help="How many pseudo data to generate for each learned task")
|
||||
parser.add_argument('--lr', type=float, default=1e-4)
|
||||
parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae', 'ae_vae_fusion'])
|
||||
parser.add_argument('--warmup', type=int, default=100, help="Amount of iterations to warmup, then decay. (-1 for no warmup and decay)")
|
||||
parser.add_argument("--train_batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--eval_batch_size", default=16, type=int, help="Batch size per GPU/CPU for evaluation.")
|
||||
parser.add_argument('--switch-time', type=float, default=0, help="Percentage of iterations to spend on short sequence training.")
|
||||
parser.add_argument('--load', type=str, help='path to load model from') # , default='out/test/'
|
||||
parser.add_argument('--workers', default=1, type=int, metavar='N', help='number of data loading workers')
|
||||
parser.add_argument('--gpu', default=0, type=int)
|
||||
parser.add_argument('--no_gpu', action="store_true")
|
||||
|
||||
# * KL cost annealing, increase beta from beta_0 to 1 in beta_warmup steps
|
||||
parser.add_argument('--beta_0', default=1.00, type=float)
|
||||
parser.add_argument('--beta_warmup', type=int, default=50000)
|
||||
parser.add_argument('--cycle', type=int, default=1000)
|
||||
parser.add_argument("--filename",type=str,help="Data original file to be preprocessed.")
|
||||
parser.add_argument("--init_model_name_or_path", default="./dir_model/gpt2", type=str, help="Path to init pre-trained model")
|
||||
parser.add_argument("--num_workers", default=1, type=int, help="workers used to process data")
|
||||
parser.add_argument("--local_rank", help='used for distributed training', type=int, default=-1)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument("--ctx_max_len", default=128, type=int, help="Maximum input length for the sequence")
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument("--weight_decay", default=0.00, type=float, help="Weight deay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--num_train_epochs", default=10, type=int, help="Total number of training epochs for each task.")
|
||||
parser.add_argument("--warmup_steps", default=-1, type=int, help="Linear warmup step. Will overwrite warmup_proportion.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Linear warmup over warmup_proportion * steps.")
|
||||
parser.add_argument("--nouse_scheduler", action='store_true', help="dont use get_linear_schedule_with_warmup, use unchanged lr")
|
||||
parser.add_argument('--logging_steps', type=int, default=10, help="Log every X updates steps.")
|
||||
parser.add_argument('--save_epochs', type=float, default=2, help="Save checkpoint every X epochs.")
|
||||
parser.add_argument('--eval_steps', type=int, default=20, help="Eval current model every X steps.")
|
||||
parser.add_argument('--eval_times_per_task', type=float, default=10, help="How many times to eval in each task, will overwrite eval_steps.")
|
||||
parser.add_argument('--seed', type=int, default=42, help="Seed for everything")
|
||||
parser.add_argument("--log_eval_res", default=True, help="Whether to log out results in evaluation process")
|
||||
parser.add_argument("--debug", action="store_true", help="Use debug mode")
|
||||
#
|
||||
args = parser.parse_args()
|
||||
args.log_file = os.path.join(args.output_dir, 'log.txt')
|
||||
|
||||
if args.local_rank in [0, -1]:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
os.makedirs(args.tb_log_dir, exist_ok=True)
|
||||
else:
|
||||
while (not os.path.isdir(args.output_dir)) or (not os.path.isdir(args.tb_log_dir)):
|
||||
pass
|
||||
|
||||
if args.debug:
|
||||
args.logging_steps = 1
|
||||
torch.manual_seed(0)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
# setup distributed training
|
||||
distributed = (args.local_rank != -1)
|
||||
if distributed:
|
||||
print(args.local_rank)
|
||||
# torch.cuda.set_device(0)
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend='nccl', init_method='env://')
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
args.device = torch.device("cpu")
|
||||
args.num_train_epochs = {task: args.num_train_epochs for task in args.tasks}
|
||||
|
||||
return args
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import json
|
||||
import os
|
||||
from settings import parse_args
|
||||
|
||||
args = parse_args()
|
||||
score_all_time = []
|
||||
with open(os.path.join(args.output_dir, "metrics.json"),"r") as f:
|
||||
score_all_time = [json.loads(row) for row in f.readlines()]
|
||||
|
||||
all_scores = []
|
||||
for row in score_all_time:
|
||||
value = list(row.values())[0]['intent_acc']
|
||||
all_scores.append(value)
|
||||
|
||||
print('%.5f'%(max(all_scores)))
|
|
@ -0,0 +1,5 @@
|
|||
(ag yelp dbpedia amazon yahoo)
(yahoo amazon ag dbpedia yelp)
(ag dbpedia yahoo yelp amazon)
(yahoo yelp amazon dbpedia ag)
(dbpedia yelp amazon ag yahoo)
|
|
@ -0,0 +1,60 @@
|
|||
from transformers import MarianMTModel, MarianTokenizer
|
||||
import torch
|
||||
|
||||
# target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
# tgt_tokz_path='out_BT/tgt_tokz'
|
||||
# tgt_model_path='out_BT/tgt_model'
|
||||
# target_tokenizer = MarianTokenizer.from_pretrained(tgt_tokz_path)
|
||||
# target_model = MarianMTModel.from_pretrained(tgt_model_path)
|
||||
|
||||
# # en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
|
||||
# en_tokz_path='out_BT/en_tokz'
|
||||
# en_model_path='out_BT/en_model'
|
||||
# en_tokenizer = MarianTokenizer.from_pretrained(en_tokz_path)
|
||||
# en_model = MarianMTModel.from_pretrained(en_model_path)
|
||||
|
||||
# target_tokenizer.save_pretrained('out_BT/tgt_tokz.pt')
|
||||
# target_model.save_pretrained('out_BT/tgt_model.pt')
|
||||
# en_tokenizer.save_pretrained('out_BT/en_tokz.pt')
|
||||
# en_model.save_pretrained('out_BT/en_model.pt')
|
||||
|
||||
# target_model, en_model = target_model.to('cuda'), en_model.to('cuda')
|
||||
|
||||
def translate(texts, model, tokenizer, language="fr"):
|
||||
# Prepare the text data into appropriate format for the model
|
||||
def template(
|
||||
text): return f"{text}" if language == "en" else f">>{language}<< {text}"
|
||||
src_texts = [template(text) for text in texts]
|
||||
|
||||
# Tokenize the texts
|
||||
encoded = tokenizer.prepare_seq2seq_batch(src_texts)
|
||||
# input_ids = tokenizer.encode(src_texts)
|
||||
for k, v in encoded.items():
|
||||
encoded[k] = torch.tensor(v)
|
||||
encoded = encoded.to('cuda')
|
||||
|
||||
# Generate translation using model
|
||||
translated = model.generate(**encoded)
|
||||
|
||||
# Convert the generated tokens indices back into text
|
||||
translated_texts = tokenizer.batch_decode(
|
||||
translated, skip_special_tokens=True)
|
||||
|
||||
return translated_texts
|
||||
|
||||
|
||||
def back_translate(texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr"):
|
||||
# Translate from source to target language
|
||||
fr_texts = translate(texts, target_model, target_tokenizer,
|
||||
language=target_lang)
|
||||
|
||||
# Translate from target language back to source language
|
||||
back_translated_texts = translate(fr_texts, en_model, en_tokenizer,
|
||||
language=source_lang)
|
||||
|
||||
return back_translated_texts
|
||||
|
||||
# en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
|
||||
# # en_texts = 'I love you.'
|
||||
# aug_texts = back_translate(en_texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr")
|
||||
# print(aug_texts)
|
|
@ -0,0 +1,42 @@
|
|||
import json, os, argparse
|
||||
from settings import parse_args
|
||||
import random
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
|
||||
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
|
||||
parser.add_argument('--num_label',type=int,help='Number of labeled data for each training dataset.')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_outdir = args.data_outdir
|
||||
train_dir = args.train_dir
|
||||
K = args.num_label
|
||||
|
||||
train_file = os.path.join(train_dir,'ripe_data','train.json')
|
||||
train_out_label = os.path.join(data_outdir,'label_train.json')
|
||||
train_out_unlabel = os.path.join(data_outdir,'unlabel_train.json')
|
||||
|
||||
with open(train_file,'r', encoding='utf-8') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
|
||||
label_idx_list = random.sample(range(len(data)), K)
|
||||
label_data = [data[i] for i in range(len(data)) if i in label_idx_list]
|
||||
print(K,len(label_data))
|
||||
unlabel_data = [data[i] for i in range(len(data)) if i not in label_idx_list]
|
||||
|
||||
with open(train_out_label, 'w', encoding='utf-8') as f:
|
||||
for i in label_data:
|
||||
print(json.dumps(i, ensure_ascii=False), file=f)
|
||||
|
||||
with open(train_out_unlabel, 'w', encoding='utf-8') as f:
|
||||
for i in unlabel_data:
|
||||
print(json.dumps(i, ensure_ascii=False), file=f)
|
||||
|
||||
str_list = ['100','500','2000']
|
||||
for i in str_list:
|
||||
random.shuffle(unlabel_data)
|
||||
with open(os.path.join(data_outdir, 'unlabel_'+i+'_train.json'), 'w', encoding='utf-8') as f:
|
||||
for j in unlabel_data[:int(i)]:
|
||||
print(json.dumps(j, ensure_ascii=False), file=f)
|
||||
|
||||
print('Finish dividing',args.train_dir)
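# Example invocation (hypothetical script name and paths):
#   python divide_data.py --train_dir DATA/ag --data_outdir DATA_Divided/DATA_100/ag --num_label 100
# --train_dir is expected to contain ripe_data/train.json; num_label examples are sampled
# from it as labeled data and the remainder is written out as unlabeled data.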
|
|
@ -0,0 +1,244 @@
|
|||
# Easy data augmentation techniques for text classification
|
||||
# Jason Wei and Kai Zou
|
||||
import nltk
|
||||
# nltk.download('wordnet', '/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
|
||||
# nltk.download('omw-1.4')
|
||||
# from nltk_data.corpora import wordnet
|
||||
from nltk.corpus import wordnet  # needed by get_synonyms() below
|
||||
nltk.data.path.append('/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
|
||||
import re
|
||||
import random
|
||||
from random import shuffle
|
||||
random.seed(1)
|
||||
|
||||
#stop words list
|
||||
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
|
||||
'ours', 'ourselves', 'you', 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his',
|
||||
'himself', 'she', 'her', 'hers', 'herself',
|
||||
'it', 'its', 'itself', 'they', 'them', 'their',
|
||||
'theirs', 'themselves', 'what', 'which', 'who',
|
||||
'whom', 'this', 'that', 'these', 'those', 'am',
|
||||
'is', 'are', 'was', 'were', 'be', 'been', 'being',
|
||||
'have', 'has', 'had', 'having', 'do', 'does', 'did',
|
||||
'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
|
||||
'because', 'as', 'until', 'while', 'of', 'at',
|
||||
'by', 'for', 'with', 'about', 'against', 'between',
|
||||
'into', 'through', 'during', 'before', 'after',
|
||||
'above', 'below', 'to', 'from', 'up', 'down', 'in',
|
||||
'out', 'on', 'off', 'over', 'under', 'again',
|
||||
'further', 'then', 'once', 'here', 'there', 'when',
|
||||
'where', 'why', 'how', 'all', 'any', 'both', 'each',
|
||||
'few', 'more', 'most', 'other', 'some', 'such', 'no',
|
||||
'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
|
||||
'very', 's', 't', 'can', 'will', 'just', 'don',
|
||||
'should', 'now', '']
|
||||
|
||||
#cleaning up text
|
||||
|
||||
def get_only_chars(line):
|
||||
clean_line = ""
|
||||
|
||||
line = line.replace("’", "")
|
||||
line = line.replace("'", "")
|
||||
line = line.replace("-", " ") # replace hyphens with spaces
|
||||
line = line.replace("\t", " ")
|
||||
line = line.replace("\n", " ")
|
||||
line = line.lower()
|
||||
|
||||
for char in line:
|
||||
if char in 'qwertyuiopasdfghjklzxcvbnm ':
|
||||
clean_line += char
|
||||
else:
|
||||
clean_line += ' '
|
||||
|
||||
clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces
|
||||
if clean_line.startswith(' '):
|
||||
clean_line = clean_line[1:]
|
||||
return clean_line
|
||||
|
||||
########################################################################
|
||||
# *Synonym replacement
|
||||
# Replace n words in the sentence with synonyms from wordnet
|
||||
########################################################################
|
||||
|
||||
|
||||
#for the first time you use wordnet
|
||||
# import nltk
|
||||
# nltk.download('wordnet')
|
||||
|
||||
def synonym_replacement(words, n):
|
||||
new_words = words.copy()
|
||||
random_word_list = list(
|
||||
set([word for word in words if word not in stop_words]))
|
||||
random.shuffle(random_word_list)
|
||||
num_replaced = 0
|
||||
for random_word in random_word_list:
|
||||
synonyms = get_synonyms(random_word)
|
||||
if len(synonyms) >= 1:
|
||||
synonym = random.choice(list(synonyms))
|
||||
new_words = [synonym if word == random_word else word for word in new_words]
|
||||
#print("replaced", random_word, "with", synonym)
|
||||
num_replaced += 1
|
||||
if num_replaced >= n: # only replace up to n words
|
||||
break
|
||||
|
||||
#this is stupid but we need it, trust me
|
||||
sentence = ' '.join(new_words)
|
||||
new_words = sentence.split(' ')
|
||||
|
||||
return new_words
|
||||
|
||||
|
||||
def get_synonyms(word):
|
||||
synonyms = set()
|
||||
for syn in wordnet.synsets(word):
|
||||
for l in syn.lemmas():
|
||||
synonym = l.name().replace("_", " ").replace("-", " ").lower()
|
||||
synonym = "".join(
|
||||
[char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
|
||||
synonyms.add(synonym)
|
||||
if word in synonyms:
|
||||
synonyms.remove(word)
|
||||
return list(synonyms)
|
||||
|
||||
########################################################################
|
||||
# *Random deletion
|
||||
# Randomly delete words from the sentence with probability p
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_deletion(words, p):
|
||||
|
||||
#obviously, if there's only one word, don't delete it
|
||||
if len(words) == 1:
|
||||
return words
|
||||
|
||||
#randomly delete words with probability p
|
||||
new_words = []
|
||||
for word in words:
|
||||
r = random.uniform(0, 1)
|
||||
if r > p:
|
||||
new_words.append(word)
|
||||
|
||||
#if you end up deleting all words, just return a random word
|
||||
if len(new_words) == 0 and len(words)>1:
|
||||
rand_int = random.randint(0, len(words)-1)
|
||||
return [words[rand_int]]
|
||||
|
||||
return new_words
|
||||
|
||||
########################################################################
|
||||
# *Random swap
|
||||
# Randomly swap two words in the sentence n times
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_swap(words, n):
|
||||
new_words = words.copy()
|
||||
for _ in range(n):
|
||||
if len(new_words) > 1:
|
||||
new_words = swap_word(new_words)
|
||||
else:
|
||||
new_words = words
|
||||
return new_words
|
||||
|
||||
|
||||
def swap_word(new_words):
|
||||
random_idx_1 = random.randint(0, len(new_words)-1)
|
||||
random_idx_2 = random_idx_1
|
||||
counter = 0
|
||||
while random_idx_2 == random_idx_1:
|
||||
random_idx_2 = random.randint(0, len(new_words)-1)
|
||||
counter += 1
|
||||
if counter > 3:
|
||||
return new_words
|
||||
new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
|
||||
return new_words
|
||||
|
||||
########################################################################
|
||||
# *Random insertion
|
||||
# Randomly insert n words into the sentence
|
||||
########################################################################
|
||||
|
||||
|
||||
def random_insertion(words, n):
|
||||
new_words = words.copy()
|
||||
for _ in range(n):
|
||||
add_word(new_words)
|
||||
return new_words
|
||||
|
||||
|
||||
def add_word(new_words):
|
||||
synonyms = []
|
||||
counter = 0
|
||||
while len(synonyms) < 1:
|
||||
if len(new_words)>1:
|
||||
random_word = new_words[random.randint(0, len(new_words)-1)]
|
||||
synonyms = get_synonyms(random_word)
|
||||
counter += 1
|
||||
if counter >= 10:
|
||||
return
|
||||
random_synonym = synonyms[0]
|
||||
random_idx = random.randint(0, len(new_words)-1)
|
||||
new_words.insert(random_idx, random_synonym)
|
||||
|
||||
########################################################################
|
||||
# main data augmentation function
|
||||
########################################################################
|
||||
|
||||
|
||||
def eda(sentence, alpha_sr=0.0, alpha_ri=0.0, alpha_rs=0.1, p_rd=0.05, num_aug=5):
|
||||
|
||||
sentence = get_only_chars(sentence)
|
||||
words = sentence.split(' ')
|
||||
words = [word for word in words if word!='']
|
||||
num_words = len(words)
|
||||
|
||||
augmented_sentences = []
|
||||
num_new_per_technique = int(num_aug/4)+1
|
||||
|
||||
#sr
|
||||
if (alpha_sr > 0):
|
||||
n_sr = max(1, int(alpha_sr*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = synonym_replacement(words, n_sr)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#ri
|
||||
if (alpha_ri > 0):
|
||||
n_ri = max(1, int(alpha_ri*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_insertion(words, n_ri)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#rs
|
||||
if (alpha_rs > 0):
|
||||
n_rs = max(1, int(alpha_rs*num_words))
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_swap(words, n_rs)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
#rd
|
||||
if (p_rd > 0):
|
||||
for _ in range(num_new_per_technique):
|
||||
a_words = random_deletion(words, p_rd)
|
||||
augmented_sentences.append(' '.join(a_words))
|
||||
|
||||
augmented_sentences = [get_only_chars(sentence)
|
||||
for sentence in augmented_sentences]
|
||||
shuffle(augmented_sentences)
|
||||
#trim so that we have the desired number of augmented sentences
|
||||
if num_aug >= 1:
|
||||
augmented_sentences = augmented_sentences[:num_aug]
|
||||
else:
|
||||
keep_prob = num_aug / len(augmented_sentences)
|
||||
augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
|
||||
#append the original sentence
|
||||
augmented_sentences.append(sentence)
|
||||
shuffle(augmented_sentences)
|
||||
return augmented_sentences
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
s = 'I like the cat with white and black fur.'
|
||||
print(eda(s))
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import json
|
||||
import sys
|
||||
|
||||
ori_dir = 'lamol_data'
|
||||
task_names = ['ag','dbpedia','yelp','yahoo']
|
||||
for task in task_names:
|
||||
dir_path = os.path.join('data_lamol','TC',task)
|
||||
if not os.path.isdir(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
# Convert train-data
|
||||
train_file = os.path.join(ori_dir,task+'_to_squad-train-v2.0.json')
|
||||
with open(train_file,'r', encoding='utf-8') as f:
|
||||
data = json.load(f)['data']
|
||||
|
||||
train_outfile = os.path.join(dir_path,'train.json')
|
||||
with open(train_outfile,'w', encoding='utf-8') as fw:
|
||||
for sample in data:
|
||||
sample_dict = {}
|
||||
x = sample['paragraphs'][0]['context'] + ' ' + sample['paragraphs'][0]['qas'][0]['question']
|
||||
y = sample['paragraphs'][0]['qas'][0]['answers'][0]['text']
|
||||
sample_dict['input'] = x
|
||||
sample_dict['output'] = y
|
||||
sample_dict['question'] = sample['paragraphs'][0]['qas'][0]['question']
|
||||
sample_dict['text_wo_question'] = sample['paragraphs'][0]['context']
|
||||
print(json.dumps(sample_dict, ensure_ascii=False), file=fw)
|
||||
|
||||
# Convert test-data
|
||||
test_file = os.path.join(ori_dir,task+'_to_squad-test-v2.0.json')
|
||||
with open(test_file,'r', encoding='utf-8') as f:
|
||||
data = json.load(f)['data']
|
||||
|
||||
test_outfile = os.path.join(dir_path,'test.json')
|
||||
with open(test_outfile,'w', encoding='utf-8') as fw:
|
||||
for sample in data:
|
||||
sample_dict = {}
|
||||
x = sample['paragraphs'][0]['context'] + ' ' + sample['paragraphs'][0]['qas'][0]['question']
|
||||
y = sample['paragraphs'][0]['qas'][0]['answers'][0]['text']
|
||||
sample_dict['input'] = x
|
||||
sample_dict['output'] = y
|
||||
sample_dict['question'] = sample['paragraphs'][0]['qas'][0]['question']
|
||||
sample_dict['text_wo_question'] = sample['paragraphs'][0]['context']
|
||||
print(json.dumps(sample_dict, ensure_ascii=False), file=fw)
|
||||
print('Finishing dealing with ',task,flush=True)
|
||||
|
||||
|
|
@ -0,0 +1,652 @@
|
|||
import torch
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
from settings import parse_args
|
||||
from eda import *
|
||||
from pretrain import *
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
max_input_length_dict = {
|
||||
'woz.en': 128,
|
||||
'sst': 128,
|
||||
'srl': 128,
|
||||
'wikisql': 300,
|
||||
'squad':512,
|
||||
'ag':128,
|
||||
'yelp':256,
|
||||
'yahoo':128,
|
||||
'dbpedia':128,
|
||||
'amazon':200
|
||||
}
|
||||
|
||||
max_ans_length_dict = {
|
||||
'woz.en': 20,
|
||||
'sst': 10,
|
||||
'srl': 100,
|
||||
'wikisql': 100,
|
||||
'squad': 50,
|
||||
'ag':10,
|
||||
'yelp':10,
|
||||
'yahoo':10,
|
||||
'dbpedia':10,
|
||||
'amazon':10
|
||||
}
|
||||
|
||||
args = parse_args()
|
||||
class LBDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, max_input_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
self.max_ans_len = args.max_ans_len
|
||||
self.special_token_ids = special_token_ids
|
||||
self.task_name = task_name
|
||||
self.max_input_len = max_input_length_dict[self.task_name]
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
if task_name in ['wikisql','woz.en'] and 'test' in self.data_path:
|
||||
with open(os.path.join(args.data_dir, task_name, 'answers.json'),'r') as f:
|
||||
self.answers = json.load(f)
|
||||
print(f'Number of answers is {len(self.answers)}')
|
||||
|
||||
if task_name in ['wikisql', 'squad'] and 'test' in self.data_path:
|
||||
self.targets = []
|
||||
for i in ori_data:
|
||||
self.targets.append(i['output'])
|
||||
|
||||
|
||||
def get_answer(self, target): # get answer text from intent
|
||||
# replace - and _ to make the text seem more natural
|
||||
if type(target) == list:
|
||||
target = target[0]
|
||||
if self.task_name == 'wikisql':
|
||||
ans = target
|
||||
else:
|
||||
ans = target.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
if self.task_name in ['wikisql','woz.en','squad','srl']:
|
||||
text = example['input']
|
||||
question = example['question']
|
||||
text_wo_question = example['text_wo_question']
|
||||
else:
|
||||
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
question = example['question'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
text_wo_question = example['text_wo_question'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
ans = self.get_answer(example['output'])
|
||||
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
|
||||
ans_list = example['output']
|
||||
return [(text, ans, question, text_wo_question, example['id'], ans_list)]
|
||||
return [(text, ans, question, text_wo_question)]
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
# print(d,flush=True)
|
||||
# input_text, ans_text = d
|
||||
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
|
||||
input_text, ans_text, question, text_wo_question, idx, ans_list = d
|
||||
else:
|
||||
input_text, ans_text, question, text_wo_question = d
|
||||
raw_text = text_wo_question
|
||||
|
||||
input_sen = task_name+':' + raw_text + '<QUES>' + question
|
||||
context_sen = input_sen + '<ANS>' # for evaluation
|
||||
ans_sen = ans_text
|
||||
# all_sen = context_sen + ans_sen
|
||||
all_sen = raw_text + '<QUES>'
|
||||
res_dict = {
|
||||
'all_sen': all_sen,
|
||||
'input_sen': input_sen,
|
||||
'context_sen': context_sen,
|
||||
'ans_sen': ans_sen,
|
||||
'question': question,
|
||||
'raw_text': raw_text
|
||||
}
|
||||
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
|
||||
res_dict['id'] = idx
|
||||
res_dict['ans_list'] = ans_list
|
||||
return res_dict
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
# print('data',data[:10],flush=True)
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def sort_by_index(self):
|
||||
self.data.sort(key=lambda x: x['id'])
|
||||
|
||||
def get_indices(self):
|
||||
return [d['id'] for d in self.data]
|
||||
|
||||
|
||||
class ULBDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task_name, tokz, data_path, max_input_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.data_path = data_path
|
||||
self.max_ans_len = args.max_ans_len
|
||||
self.special_token_ids = special_token_ids
|
||||
self.task_name = task_name
|
||||
self.max_input_len = max_input_length_dict[self.task_name]
|
||||
|
||||
with open(data_path, "r", encoding='utf-8') as f:
|
||||
# data = [json.loads(i) for i in f.readlines()]
|
||||
# data = [self.parse_example(i) for i in data] # [utter, label]
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
if task_name in ['wikisql','woz.en'] and 'test' in self.data_path:
|
||||
with open(os.path.join(args.data_dir, task_name, 'answers.json'),'r') as f:
|
||||
self.answers = json.load(f)
|
||||
|
||||
def get_answer(self, target): # get answer text from intent
|
||||
if type(target) == list:
|
||||
target = target[0]
|
||||
if self.task_name == 'wikisql':
|
||||
ans = target
|
||||
else:
|
||||
ans = target.replace("-", " ").replace("_", " ")
|
||||
return ans
|
||||
|
||||
def parse_example(self, example):
|
||||
if self.task_name in ['wikisql','woz.en','squad','srl']:
|
||||
text = example['input']
|
||||
question = example['question']
|
||||
text_wo_question = example['text_wo_question']
|
||||
else:
|
||||
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
question = example['question'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
text_wo_question = example['text_wo_question'].replace('-',' ').replace('_',' ').replace('\\',' ')
|
||||
ans = self.get_answer(example['output'])
|
||||
return [(text, ans, question, text_wo_question)]
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
input_text, ans_text, question, text_wo_question = d
|
||||
raw_text = text_wo_question
|
||||
|
||||
input_sen = task_name+':' + raw_text + '<QUES>' + question
|
||||
context_sen = input_sen + '<ANS>' # for evaluation
|
||||
ans_sen = ans_text
|
||||
# all_sen = context_sen + ans_sen
|
||||
all_sen = raw_text + '<QUES>'
|
||||
return {
|
||||
'all_sen': all_sen,
|
||||
'input_sen': input_sen,
|
||||
'context_sen': context_sen,
|
||||
'ans_sen': ans_sen,
|
||||
'question': question,
|
||||
'raw_text': raw_text
|
||||
}
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
class MemoryDataset(ULBDataset):
|
||||
def __init__(self, res):
|
||||
self.max_ans_len = args.max_ans_len
|
||||
self.max_input_len = args.max_input_len
|
||||
self.data = res['batch_text'] #list of dict
|
||||
|
||||
class PinnedBatch:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, k):
|
||||
return self.data[k]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def pin_memory(self):
|
||||
for k in self.data.keys():
|
||||
if type(self.data[k])!=list:
|
||||
self.data[k] = self.data[k].pin_memory()
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return self.data.__repr__()
|
||||
|
||||
def __str__(self):
|
||||
return self.data.__str__()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
|
||||
def pad_seq(seq, pad, max_len, pad_left=False):
|
||||
if pad_left:
|
||||
return [pad] * (max_len - len(seq)) + seq
|
||||
else:
|
||||
return seq + [pad] * (max_len - len(seq))
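# Example (from the logic above): pad_seq([5, 6], pad=0, max_len=4) -> [5, 6, 0, 0];
# with pad_left=True the padding goes in front: [0, 0, 5, 6].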
|
||||
|
||||
|
||||
class PadBatchSeq:
|
||||
def __init__(self, pad_id=0, tokz=None, task=None):
|
||||
self.pad_id = pad_id
|
||||
self.tokz = tokz
|
||||
self.task = task
|
||||
|
||||
def __call__(self, batch):
|
||||
if self.task is not None:
|
||||
max_input_len = max_input_length_dict[self.task]
|
||||
max_ans_len = max_ans_length_dict[self.task]
|
||||
else:
|
||||
max_input_len = 600
|
||||
max_ans_len = args.max_ans_len
|
||||
|
||||
input_sen = [i['input_sen'] for i in batch]
|
||||
context_sen = [i['context_sen'] for i in batch]
|
||||
all_sen = [i['all_sen'] for i in batch]
|
||||
|
||||
if 'ans_sen' in batch[0]:
|
||||
ans_sen = [i['ans_sen'] for i in batch]
|
||||
ans_encoding = self.tokz(ans_sen, padding='longest', max_length=max_ans_len, truncation=True, return_tensors='pt')
|
||||
raw_sen = [i['raw_text'] for i in batch]
|
||||
ques_sen = [i['question'] for i in batch]
|
||||
|
||||
input_encoding = self.tokz(input_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
|
||||
all_encoding = self.tokz(all_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
|
||||
context_encoding = self.tokz(context_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
|
||||
raw_text_encoding = self.tokz(raw_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
|
||||
question_encoding = self.tokz(ques_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
|
||||
|
||||
res = {}
|
||||
res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
|
||||
res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
|
||||
res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
|
||||
if 'ans_sen' in batch[0]:
|
||||
res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
|
||||
res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
|
||||
res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
|
||||
res['batch_text'] = batch
|
||||
|
||||
return PinnedBatch(res)
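# Minimal usage sketch (assumed wiring; `tokz` is the T5 tokenizer loaded elsewhere in this repo):
#   loader = DataLoader(dataset, batch_size=16,
#                       collate_fn=PadBatchSeq(pad_id=tokz.pad_token_id, tokz=tokz, task='ag'))
# Each batch is then a PinnedBatch holding the tokenized input/context/answer tensors plus the raw 'batch_text' dicts.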
|
||||
|
||||
# * Information for different tasks.
|
||||
if args.data_type == 'intent':
|
||||
TASK2INFO = {
|
||||
"banking": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "banking",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"clinc": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "clinc",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"hwu": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "hwu",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"tod": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "FB_TOD_SF",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split1": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "top_split1",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split2": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "top_split2",
|
||||
"task_type": "CLS",
|
||||
},
|
||||
"top_split3": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "top_split3",
|
||||
"task_type": "CLS",
|
||||
}
|
||||
}
|
||||
elif args.data_type =='slot':
|
||||
TASK2INFO = {
|
||||
"dstc8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "dstc8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"restaurant8": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "restaurant8",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"atis": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "atis",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_eng": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_eng",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_movie_trivia": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_movie_trivia",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"mit_restaurant": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "mit_restaurant",
|
||||
"task_type": "SlotTagging",
|
||||
},
|
||||
"snips": {
|
||||
"dataset_class": PromptSlotTaggingDataset,
|
||||
"dataset_folder": "snips",
|
||||
"task_type": "SlotTagging",
|
||||
}
|
||||
}
|
||||
elif args.data_type =='decanlp' or args.data_type == 'tc' or args.data_type == 'mix':
|
||||
TASK2INFO = {
|
||||
"ag": {
|
||||
"dataset_class": LBDataset,
|
||||
"dataset_folder": "ag",
|
||||
"task_type": "tc",
|
||||
},
|
||||
'dbpedia':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'dbpedia',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'yelp':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'yelp',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'yahoo':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'yahoo',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'snips':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'snips',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'amazon':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'amazon',
|
||||
'task_type':'tc',
|
||||
},
|
||||
'woz.en':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'woz.en',
|
||||
'task_type':'decanlp',
|
||||
},
|
||||
'squad':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'squad',
|
||||
'task_type':'decanlp',
|
||||
},
|
||||
'sst':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'sst',
|
||||
'task_type':'decanlp',
|
||||
},
|
||||
'srl':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'srl',
|
||||
'task_type':'decanlp',
|
||||
},
|
||||
'wikisql':{
|
||||
'dataset_class':LBDataset,
|
||||
'dataset_folder':'wikisql',
|
||||
'task_type':'decanlp',
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_unlabel_data(path, task):
|
||||
info = TASK2INFO[task]
|
||||
# data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
|
||||
if args.unlabel_ratio == -1:
|
||||
num_unlabel = 2000
|
||||
else:
|
||||
num_unlabel = args.num_label*args.unlabel_ratio
|
||||
data_path = os.path.join(path, info['dataset_folder'],str(num_unlabel)+'_unlabel_train.json')
|
||||
return data_path
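# Example (hypothetical values): with num_label=100 and unlabel_ratio=10 this resolves to
# <data_dir>/<dataset_folder>/1000_unlabel_train.json; with unlabel_ratio == -1 the count is fixed at 2000.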
|
||||
|
||||
def get_unlabel_dict(path, task, tokz, args):
|
||||
info = TASK2INFO[task]
|
||||
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
if args.data_type == 'intent':
|
||||
return LBDataset(task, tokz, path, args.max_input_len, special_token_ids=special_token_ids)
|
||||
else:
|
||||
return None
|
||||
|
||||
def write_mix_train_file(label_train_dataset, unlabel_train_dataset, out_file, oridir):
|
||||
datatype_list=['label_train','unlabel_train']
|
||||
with open(out_file,'w') as fw:
|
||||
for datatype in datatype_list:
|
||||
datapath = os.path.join(oridir,datatype+'.json')
|
||||
with open(datapath,'r') as f:
|
||||
data = [json.loads(i) for i in f.readlines()]
|
||||
for row in data:
|
||||
# print(json.dumps(row['input'].strip('"'), ensure_ascii=False),file=fw)
|
||||
print(row['input'].strip('"'), file=fw)
|
||||
|
||||
def create_dataloader_for_pretrain(mix_train_file, tokz, model, args):
|
||||
data_files = {}
|
||||
data_files['train'] = mix_train_file
|
||||
extension = mix_train_file.split(".")[-1]
|
||||
if extension == 'txt': extension = 'text'
|
||||
datasets = load_dataset(extension, data_files=data_files)
|
||||
train_dataset = datasets['train']
|
||||
|
||||
max_seq_length=128
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokz(examples[text_column_name], return_attention_mask=False)
|
||||
|
||||
column_names = train_dataset.column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
tokenized_datasets = train_dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
)
|
||||
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
|
||||
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
|
||||
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
|
||||
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
|
||||
inputs_length=max_seq_length,
|
||||
noise_density=0.15,
|
||||
mean_noise_span_length=3.0,
|
||||
)
|
||||
|
||||
data_collator = DataCollatorForT5MLM(
|
||||
tokenizer=tokz,
|
||||
noise_density=0.15,
|
||||
mean_noise_span_length=3.0,
|
||||
input_length=max_seq_length,
|
||||
target_length=targets_length,
|
||||
pad_token_id=model.config.pad_token_id,
|
||||
decoder_start_token_id=model.config.decoder_start_token_id,
|
||||
)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {
|
||||
k: list(chain(*examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
if total_length >= expanded_inputs_length:
|
||||
total_length = (
|
||||
total_length // expanded_inputs_length) * expanded_inputs_length
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i: i + expanded_inputs_length]
|
||||
for i in range(0, total_length, expanded_inputs_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
return result
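# Example: if expanded_inputs_length were 10, a concatenated stream of 23 token ids yields
# two chunks of 10 and the trailing 3 ids are dropped (as noted in the comment above).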
|
||||
|
||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
||||
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
||||
# might be slower to preprocess.
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
train_dataset = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
load_from_cache_file=True,
|
||||
)
|
||||
train_sampler = torch.utils.data.RandomSampler(train_dataset)
|
||||
mix_dataloader = DataLoader(train_dataset, batch_size=16,
|
||||
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=data_collator)
|
||||
return mix_dataloader
|
||||
|
||||
|
||||
class MixedDataset(LBDataset):
|
||||
def __init__(self, unlabel_data, label_data, tokz, max_input_len=100):
|
||||
self.max_input_len = max_input_len
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = args.max_ans_len
|
||||
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
|
||||
|
||||
self.data = []
|
||||
self.data += unlabel_data
|
||||
self.data += label_data
|
||||
|
||||
class AugDataset(LBDataset):
|
||||
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, max_input_len=100, special_token_ids=None):
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = args.max_ans_len
|
||||
self.special_token_ids = special_token_ids
|
||||
self.max_input_len = max_input_length_dict[task_name]
|
||||
with open(label_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data = [json.loads(i) for i in f.readlines()]
|
||||
with open(unlabel_data_path, "r", encoding='utf-8') as f:
|
||||
ori_data += [json.loads(i) for i in f.readlines()]
|
||||
data = []
|
||||
for i in ori_data:
|
||||
data += self.parse_example(i)
|
||||
|
||||
self.data = []
|
||||
if len(data) > 0:
|
||||
self.data = self.data_tokenization(task_name, data)
|
||||
|
||||
def parse_example(self, example):
|
||||
text = example['userInput']['text']
|
||||
aug_text_list = eda(text, num_aug=args.num_aug)
|
||||
ans = self.get_answer(example['intent'])
|
||||
res = [(aug_text, ans) for aug_text in aug_text_list]
|
||||
res.append((text, ans))
|
||||
return res
|
||||
|
||||
def get_datasets(path, tasks, tokz, max_input_len=100, special_token_ids=None):
|
||||
res = {}
|
||||
|
||||
for task in tasks:
|
||||
res[task] = {}
|
||||
info = TASK2INFO[task]
|
||||
if args.unlabel_ratio == -1:
|
||||
num_unlabel = 2000
|
||||
else:
|
||||
num_unlabel = args.num_label * args.unlabel_ratio
|
||||
max_input_len = max_input_length_dict[task]
|
||||
|
||||
# ! Remember to restore this.
|
||||
res[task]['label_train'] = LBDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
|
||||
if args.use_unlabel:
|
||||
res[task]['unlabel_train'] = ULBDataset(
|
||||
task, tokz, os.path.join(path, info['dataset_folder'], str(num_unlabel)+'_unlabel_train.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
|
||||
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
|
||||
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
|
||||
return res
|
||||
|
||||
|
||||
class PseudoDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, task, task_data, tokz, max_input_len=100, curr_data=None):
|
||||
'''
|
||||
task2data : {'task_name': [data1, data2, data3, ...]}
|
||||
'''
|
||||
self.max_input_len = max_input_length_dict[task]
|
||||
self.tokz = tokz
|
||||
self.max_ans_len = args.max_ans_len
|
||||
|
||||
self.data = []
|
||||
# for task in task2data:
|
||||
# self.data += self.data_tokenization(task, task2data[task])
|
||||
self.data += self.data_tokenization(task, task_data)
|
||||
if curr_data is not None:
|
||||
self.data += curr_data
|
||||
|
||||
def per_tokenization(self, task_name, d):
|
||||
# print(d,flush=True)
|
||||
# input_text, ans_text = d
|
||||
raw_text, ans_text, question = d
|
||||
|
||||
input_sen = task_name+':' + raw_text + '<QUES>' + question
|
||||
context_sen = input_sen + '<ANS>' # for evaluation
|
||||
ans_sen = ans_text
|
||||
# all_sen = context_sen + ans_sen
|
||||
all_sen = raw_text + '<QUES>'
|
||||
return {
|
||||
'all_sen': all_sen,
|
||||
'input_sen': input_sen,
|
||||
'context_sen': context_sen,
|
||||
'ans_sen': ans_sen,
|
||||
'question': question,
|
||||
'raw_text': raw_text
|
||||
}
|
||||
|
||||
def data_tokenization(self, task_name, data):
|
||||
data = [self.per_tokenization(task_name, i) for i in data]
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
|
@ -0,0 +1,472 @@
|
|||
from transformers import T5Tokenizer, T5Config
|
||||
import logging
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import time
|
||||
import gc
|
||||
import pickle
|
||||
import argparse
|
||||
import math
|
||||
import re
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as data
|
||||
import torch.multiprocessing as mp
|
||||
import copy
|
||||
import torch.nn.functional as F
|
||||
t5tokz = T5Tokenizer.from_pretrained('t5-base')
|
||||
|
||||
# ! Get answers given inputs.
|
||||
def get_answer(tokz, model, example, example_mask, max_ans_len, args=None, is_eval=False):
|
||||
temperature = args.temperature
|
||||
model.eval()
|
||||
device = 'cuda'
|
||||
|
||||
eos_token = tokz.eos_token_id
|
||||
pad_token = tokz.pad_token_id
|
||||
bad_words = t5tokz.additional_special_tokens
|
||||
bad_words_ids = tokz(bad_words, add_special_tokens=False).input_ids
|
||||
# bad_words_ids = [[badid] for badid in tokz.additional_special_tokens_ids]
|
||||
|
||||
# * Get sequence outputs.
|
||||
# print('Begin generating answers.',flush=True)
|
||||
outputs = model.generate(input_ids=example, attention_mask=example_mask, do_sample=False, eos_token_id=eos_token, pad_token_id=pad_token,
|
||||
output_scores=True, return_dict_in_generate=True, early_stopping=True, max_length=max_ans_len, bad_words_ids=bad_words_ids)
|
||||
|
||||
# print(outputs,flush=True)
|
||||
output_seq = outputs.sequences
|
||||
output_scores = outputs.scores
|
||||
if not is_eval:
|
||||
output_scores = torch.stack(list(output_scores), dim=0)
|
||||
output_scores = output_scores.permute(1, 0, 2) # [80,10,50258]
|
||||
# print(f'output_scores {output_scores.shape}',flush=True)
|
||||
# print(f'output_seq {output_seq.shape}',flush=True)
|
||||
# print('output_seq',output_seq.shape, 'output_scores',len(output_scores), output_scores[0].shape, flush=True)
|
||||
# generate one: example [16, 29], out_seq [16, 39], out_score[0] [16, 50258], out_score [max_length=input_ids.shape[-1], batch_size, vocab_size] tuple
|
||||
# generate many: out_seq [16, 5, 29], out_scores [16, 5, 10, 50258]
|
||||
# return output_seq[:,example.shape[1]:], output_scores
|
||||
return output_seq, output_scores
|
||||
|
||||
def textid_decode(text, eos, tokz, contain_eos=False):
|
||||
if eos in text:
|
||||
if contain_eos:
|
||||
idx = text.index(eos)
|
||||
text = text[:idx + 1]  # keep the EOS id when contain_eos=True
|
||||
else:
|
||||
text = text[:text.index(eos)]
|
||||
text_id = text
|
||||
len_text = len(text)
|
||||
text = tokz.decode(text).strip()
|
||||
return text, len_text, text_id
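# Returns the decoded string, the number of kept ids, and the kept id list; the sequence is cut at the
# first EOS id, which is retained in the returned ids when contain_eos=True.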
|
||||
|
||||
|
||||
def padding_convert(text_list, eos):
|
||||
tt_list = []
|
||||
for text in text_list:
|
||||
if eos in text:
|
||||
eos_indexs = [i for i, x in enumerate(text) if x == eos]
|
||||
# print('eos index', eos_indexs,flush=True)
|
||||
if len(eos_indexs) > 1:
|
||||
text = text[:eos_indexs[1]]
|
||||
tt_list.append(text)
|
||||
tt_lens = [len(i) for i in tt_list]
|
||||
tt_lens = torch.tensor(tt_lens, dtype=torch.long).to('cuda')
|
||||
tt_pad = torch.tensor([pad_seq(i, eos, max(tt_lens), pad_left=True)
|
||||
for i in tt_list], dtype=torch.long).to('cuda')
|
||||
return tt_pad, tt_lens
|
||||
|
||||
|
||||
def gen_pseudo_data(model, task, tokz, max_output_len=90, batch_size=4, target_count=100, output_file=None, args=None, question=None):
|
||||
device = 'cuda'
|
||||
task_prefix = task+':'
|
||||
task_prefix_id = tokz.encode(task_prefix)[0]
|
||||
# task_prefix_id = tokz.encode('Task:')[0]
|
||||
input_tokens = torch.full((batch_size,1), task_prefix_id).cuda()
|
||||
adapter_name = task.replace('.','_')
|
||||
model.set_active_adapters(adapter_name)
|
||||
|
||||
model = model.cuda()
|
||||
|
||||
pseudo_list = []
|
||||
utter_set = set()
|
||||
eos_token = tokz.eos_token_id
|
||||
pad_token = tokz.pad_token_id
|
||||
# bad_words = tokz.additional_special_tokens
|
||||
# bad_words_ids = [[badid] for badid in tokz.additional_special_tokens_ids]
|
||||
bad_words = t5tokz.additional_special_tokens
|
||||
bad_words_ids = tokz(bad_words, add_special_tokens=False).input_ids
|
||||
|
||||
if output_file is None:
|
||||
raise ValueError("Pseudo output file is not specified.")
|
||||
if not os.path.isdir(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
|
||||
# * Generate pseudo samples.
|
||||
input_set = set()
|
||||
while len(pseudo_list) < target_count:
|
||||
once_target_count = min(32, target_count)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs = model.generate(input_ids=input_tokens, do_sample=True, eos_token_id=eos_token, pad_token_id=pad_token,
|
||||
output_scores=True, return_dict_in_generate=True, early_stopping=True, max_length=min(max_output_len, 256),
|
||||
num_return_sequences=once_target_count, bad_words_ids=bad_words_ids)
|
||||
output_seq = outputs.sequences
|
||||
output_list = output_seq.tolist()
|
||||
# print('output_list len',len(output_list), flush=True)
|
||||
# print('batch size',batch_size, flush=True)
|
||||
|
||||
for i in range(len(output_list)):  # iterate over all returned sequences (batch_size * num_return_sequences)
|
||||
# print("Test what's the first token:",output_list[i],flush=True)
|
||||
if len(output_list[i])<=1: continue
|
||||
output_id = output_list[i][1:]
|
||||
# print(output_id, flush=True)
|
||||
if eos_token in output_id:
|
||||
output_id = output_id[:output_id.index(eos_token)]
|
||||
output = tokz.decode(output_id, skip_special_tokens=True) #! Not contain '<QUES>' token
|
||||
# print(output, flush=True)
|
||||
|
||||
if len(output)>=2:
|
||||
input_sen = output.strip()
|
||||
context_sen = task_prefix + input_sen + '<QUES>' + question + '<ANS>'
|
||||
input_ids = torch.tensor([tokz.encode(context_sen)]).cuda()
|
||||
ans_id = model.generate(input_ids=input_ids, do_sample=False, eos_token_id=eos_token, pad_token_id=pad_token, bad_words_ids=bad_words_ids)
|
||||
output_sen = tokz.decode(ans_id[0], skip_special_tokens=True)
|
||||
else:
|
||||
input_sen, output_sen = None, None
|
||||
|
||||
if input_sen and output_sen:
|
||||
print('INPUT::',input_sen, '===> OUTPUT::', output_sen, flush=True)
|
||||
if input_sen not in input_set:
|
||||
input_set.add(input_sen) # avoid duplicate inputs.
|
||||
pseudo_list.append([input_sen, output_sen, question])
|
||||
|
||||
pseudo_list = pseudo_list[:target_count]
|
||||
|
||||
with open(output_file, 'w', encoding='utf8') as f:
|
||||
for input, output, _ in pseudo_list:
|
||||
print(json.dumps({'input': input, 'output': output}, ensure_ascii=False), file=f)
|
||||
model.train()
|
||||
return pseudo_list
|
||||
|
||||
|
||||
def get_all_priors(model, tokz, args):
|
||||
def pseudo_prompt(task):
|
||||
return f"In the \"{task}\" task, which intent category best describes: \""
|
||||
all_prior_info = {}
|
||||
for task in args.tasks:
|
||||
prompt_id = [tokz.bos_token_id]+tokz.encode(pseudo_prompt(task))+[tokz.eos_token_id]
|
||||
prompt_id = torch.LongTensor(prompt_id).to('cuda')
|
||||
prior_out = model.encoder(input_ids=prompt_id)
|
||||
prior_emb, _ = model.avg_attn(prior_out[0])
|
||||
prior_mean, prior_logvar = model.prior_mean(prior_emb), model.prior_logvar(prior_emb)
|
||||
all_prior_info[task]=(prior_mean, prior_logvar)
|
||||
return all_prior_info
|
||||
|
||||
def get_nearest_task(model, tokz, sample, all_prior_info, args):
|
||||
def pseudo_prompt(task):
|
||||
return f"In the \"{task}\" task, which intent category best describes: \""
|
||||
all_posteriors={}
|
||||
batch_size = len(sample['utter_id'])
|
||||
for task in args.tasks:
|
||||
prompt_id = [tokz.bos_token_id]+tokz.encode(pseudo_prompt(task))
|
||||
bt_prompt_id = torch.LongTensor([prompt_id for _ in range(batch_size)]).to('cuda')
|
||||
bt_px_id = torch.cat((bt_prompt_id,sample['utter_id'].to('cuda')),dim=1)
|
||||
bt_px_id = bt_px_id.to('cuda')
|
||||
if len(bt_px_id)!=batch_size:
|
||||
raise ValueError('Tensor concatenation failed: batch size mismatch.')
|
||||
# px_id = tokz.encode(prompt_id+sample['utter_id'].tolist())
|
||||
# prompt_id = torch.LongTensor(prompt_id).to('cuda')
|
||||
# px_id = torch.LongTensor(px_id).to('cuda')
|
||||
post_out = model.encoder(input_ids=bt_px_id)
|
||||
post_emb, _ = model.avg_attn(post_out[0])
|
||||
post_mean, post_logvar = model.post_mean(post_emb), model.post_logvar(post_emb)
|
||||
all_posteriors[task]=(post_mean, post_logvar)
|
||||
|
||||
min_kl = 1e10
|
||||
res_task = args.tasks[0]
|
||||
all_kl_dist = []
|
||||
for task in all_prior_info.keys():
|
||||
prior_mean, prior_logvar = all_prior_info[task]
|
||||
post_mean, post_logvar = all_posteriors[task]
|
||||
kl_dist = kl_divergence(post_mean, post_logvar, prior_mean, prior_logvar)
|
||||
all_kl_dist.append(kl_dist)
|
||||
if kl_dist < min_kl:
|
||||
min_kl = kl_dist
|
||||
res_task = task
|
||||
# print(all_kl_dist,flush=True)
|
||||
return res_task
|
||||
|
||||
def get_pred_context(tokz, pred_task_name, gt_task_name, sample):
|
||||
new_list = []
|
||||
for ss in sample['context_id'].tolist():
|
||||
# print('1',len(ss),flush=True)
|
||||
context = tokz.decode(ss)
|
||||
# print('1',context,flush=True)
|
||||
# new_context = text.replace(f'In the "{gt_task_name}" task, which intent category best describes: "', pred_task_name).strip()
|
||||
new_context = re.sub(gt_task_name,pred_task_name,context)
|
||||
# print('2',new_context,flush=True)
|
||||
new_context_id = tokz.encode(new_context)
|
||||
# print('2',len(new_context_id),flush=True)
|
||||
|
||||
new_list.append(new_context_id)
|
||||
# if len(new_context_id)!=len(new_list[0]):
|
||||
# raise ValueError('Lengths are not match in this batch.')
|
||||
context_lens = [len(i) for i in new_list]
|
||||
context_mask = torch.ByteTensor([[1] * context_lens[i] + [0] * (max(context_lens)-context_lens[i]) for i in range(len(context_lens))])
|
||||
new_res = torch.tensor([pad_seq(i, tokz.eos_token_id, max(context_lens), pad_left=True) for i in new_list], dtype=torch.long).to('cuda')
|
||||
new_lens = torch.tensor(context_lens,dtype=torch.long).to('cuda')
|
||||
# new_res = torch.LongTensor(new_list).to('cuda',non_blocking=True)
|
||||
return new_res, new_lens
|
||||
|
||||
def get_pseudo_labels_for_all(tokz, model, dataloader, task, sampling=False, args=None):
|
||||
out_dir = os.path.join(args.output_dir, f'{task}_unlabel_pseudo.json')
|
||||
out_dir_data = []
|
||||
data_path = get_unlabel_data(args.data_dir, task)
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
all_data = [json.loads(i) for i in f.readlines()]
|
||||
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
pred_ans_all = []
|
||||
input_all = []
|
||||
for i, data in enumerate(dataloader):
|
||||
sample = data['context_id'].to('cuda')
|
||||
sample_lens = data['context_lens'].to('cuda')
|
||||
pred_ans, _ = get_answer(tokz, model, sample, sample_lens, max_ans_len=args.max_ans_len, args=args)
|
||||
pred_ans_all.append(pred_ans)
|
||||
pred_ans_all = communicate_tensor(pred_ans_all, pad_token=tokz.eos_token_id).tolist()
|
||||
for i in range(len(pred_ans_all)):
|
||||
res = {}
|
||||
res['userInput'] = {}
|
||||
res['userInput']['text'] = all_data[i]['userInput']['text']
|
||||
res['intent'] = tokz.decode(cut_eos(pred_ans_all[i], tokz.eos_token_id))
|
||||
out_dir_data.append(res)
|
||||
|
||||
with open(out_dir, 'w', encoding='utf-8') as f:
|
||||
for res in out_dir_data:
|
||||
print(json.dumps(res), file=f)
|
||||
model.train()
|
||||
|
||||
return None
|
||||
|
||||
def get_pseudo_labels_per_batch(tokz, model, data, sampling=False, args=None):
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
context = data['context_id'].to('cuda')
|
||||
context_lens = data['context_lens'].to('cuda')
|
||||
out_seqs, out_scores = get_answer(tokz, model, context, context_lens, max_ans_len=args.max_ans_len, args=args)
|
||||
|
||||
# print('out_seqs',out_seqs,'out_scores',out_scores,flush=True)
|
||||
# Get all lens.
|
||||
eos_token = tokz.eos_token_id
|
||||
ans_lens = []
|
||||
max_seq_len = out_seqs.size(-1)
|
||||
btz = out_seqs.size(0)
|
||||
out_seqs_list = out_seqs.tolist()
|
||||
# print(out_seqs_list,flush=True)
|
||||
all_ans_id = []
|
||||
for i in out_seqs_list:
|
||||
# print(i,flush=True)
|
||||
_, seq_len, seq_id = textid_decode(i, eos_token, tokz, contain_eos=True)
|
||||
ans_lens.append(seq_len) # Contain one EOS token.
|
||||
all_ans_id.append(seq_id)
|
||||
|
||||
# print('ans len', ans_lens, flush=True)
|
||||
# print('max_seq_len',max_seq_len, 'max(all_lens)',max(ans_lens), flush=True)
|
||||
# ans_scores = [[out_scores[context_lens[i]:all_lens[i]]] for i in range(btz)]
|
||||
# * Compute PPL.
|
||||
_, ppl_batch = compute_ppl(out_seqs, out_scores, tokz)
|
||||
|
||||
# * Prepare batch for training.
|
||||
spe_token_id = tokz.convert_tokens_to_ids('ANS:')
|
||||
batch_list = []
|
||||
for i in range(btz):
|
||||
info = {}
|
||||
input_id = data['input_id'][i][:data['input_lens'][i]]
|
||||
input_id = input_id.tolist()
|
||||
info['input_id'] = input_id
|
||||
info['ans_id'] = all_ans_id[i] + [tokz.eos_token_id]
|
||||
info['context_id'] = input_id + [spe_token_id]
|
||||
info['all_id'] = info['context_id'] + all_ans_id[i] + [tokz.eos_token_id]
|
||||
batch_list.append(info)
|
||||
|
||||
# Print some pseudo samples.
|
||||
if i==0:
|
||||
print('-'*10, 'Unlabeled sample with pseudo labels:', flush=True)
|
||||
print('input:', tokz.decode(info['input_id']), flush=True)
|
||||
print('all:',tokz.decode(info['all_id']), flush=True)
|
||||
print('-'*20,flush=True)
|
||||
# print(batch_list[0],flush=True)
|
||||
padbatch = PadBatchSeq()
|
||||
batch = padbatch(batch_list)
|
||||
print('ppl',ppl_batch.tolist(),flush=True)
|
||||
return ppl_batch, batch
|
||||
|
||||
def compute_ppl(out_seqs, out_scores, tokz):
|
||||
# * Get answer lens.
|
||||
ans_lens = []
|
||||
max_seq_len = out_seqs.size(-1)
|
||||
out_seqs_list = out_seqs.tolist()
|
||||
all_ans_id = []
|
||||
for i in out_seqs_list:
|
||||
_, seq_len, seq_id = textid_decode(i, tokz.eos_token_id, tokz, contain_eos=True)
|
||||
ans_lens.append(seq_len) # Contain one EOS token.
|
||||
all_ans_id.append(seq_id)
|
||||
|
||||
btz = out_seqs.size(0)
|
||||
ans_scores = out_scores # may get -inf
|
||||
# print(f'answer scores max {torch.max(ans_scores)} min {torch.min(ans_scores)}')
|
||||
# for j in range(btz):
|
||||
# score_per_sample = []
|
||||
# for l in range(max_seq_len):
|
||||
# score_per_sample.append(out_scores[l][j].tolist())
|
||||
# ans_scores.append(score_per_sample)
|
||||
# ans_scores = torch.tensor(ans_scores)
|
||||
# ans_scores = out_scores.permute(1, 0, 2)
|
||||
|
||||
log_softmax = nn.LogSoftmax(dim=-1) # for vocab_size
|
||||
log_soft = log_softmax(ans_scores)
|
||||
select_log_soft = []
|
||||
log_soft_flatten = log_soft.view(-1, ans_scores.shape[-1]).to('cuda')
|
||||
# log_soft_flatten = log_soft.view(-1, len(tokz)).to('cuda')
|
||||
out_seq_flatten = out_seqs.contiguous().view(-1).to('cuda')
|
||||
select_log_soft = log_soft_flatten[torch.arange(log_soft_flatten.shape[0]),out_seq_flatten]
|
||||
select_log_soft = select_log_soft.view(btz, max_seq_len)
|
||||
# ans_scores:[16, 10, 50258] log_soft [16, 10, 50258] select_log_soft [16, 10]
|
||||
ans_score_mask = [[1]* ans_lens[i]+[0]*(max_seq_len-ans_lens[i]) for i in range(btz)]
|
||||
ans_score_mask = torch.tensor(ans_score_mask).to('cuda')
|
||||
neg_loglh = torch.sum(-select_log_soft * ans_score_mask, dim=1) / torch.sum(ans_score_mask, dim=1)
|
||||
ppl_batch = torch.exp(neg_loglh).to('cuda')
|
||||
|
||||
return all_ans_id, ppl_batch
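# Per-example perplexity computed above: ppl_i = exp(-(1/L_i) * sum_t log p(y_{i,t})),
# where L_i is the masked answer length and the sum runs over the generated answer tokens only.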
|
||||
|
||||
|
||||
def old_compute_ppl(out_seqs, out_scores, tokz):
|
||||
# * Get answer lens.
|
||||
ans_lens = []
|
||||
max_seq_len = out_seqs.size(-1)
|
||||
out_seqs_list = out_seqs.tolist()
|
||||
all_ans_id = []
|
||||
for i in out_seqs_list:
|
||||
_, seq_len, seq_id = textid_decode(i, tokz.eos_token_id, tokz, contain_eos=True)
|
||||
ans_lens.append(seq_len) # Contain one EOS token.
|
||||
all_ans_id.append(seq_id)
|
||||
|
||||
btz = out_seqs.size(0)
|
||||
ans_scores = out_scores
|
||||
# for j in range(btz):
|
||||
# score_per_sample = []
|
||||
# for l in range(max_seq_len):
|
||||
# score_per_sample.append(out_scores[l][j].tolist())
|
||||
# ans_scores.append(score_per_sample)
|
||||
# ans_scores = torch.tensor(ans_scores)
|
||||
# ans_scores = out_scores.permute(1, 0, 2)
|
||||
|
||||
log_softmax = nn.LogSoftmax(dim=-1) # for vocab_size
|
||||
log_soft = log_softmax(ans_scores)
|
||||
select_log_soft = []
|
||||
log_soft_flatten = log_soft.view(-1, len(tokz)).to('cuda')
|
||||
out_seq_flatten = out_seqs.contiguous().view(-1).to('cuda')
|
||||
select_log_soft = log_soft_flatten[torch.arange(log_soft_flatten.shape[0]),out_seq_flatten]
|
||||
select_log_soft = select_log_soft.view(btz, max_seq_len)
|
||||
# ans_scores:[16, 10, 50258] log_soft [16, 10, 50258] select_log_soft [16, 10]
|
||||
ans_score_mask = [[1]* ans_lens[i]+[0]*(max_seq_len-ans_lens[i]) for i in range(btz)]
|
||||
ans_score_mask = torch.tensor(ans_score_mask).to('cuda')
|
||||
neg_loglh = torch.sum(-select_log_soft * ans_score_mask, dim=1) / torch.sum(ans_score_mask, dim=1)
|
||||
ppl_batch = torch.exp(neg_loglh).to('cuda')
|
||||
|
||||
return all_ans_id, ppl_batch
|
||||
|
||||
def map_to_same_length(seq_list, tokz):
|
||||
max_seq_len = 10
|
||||
eos_token = tokz.eos_token_id
|
||||
new_seq_list = []
|
||||
for i in range(len(seq_list)):
|
||||
if len(seq_list[i]) < max_seq_len:
|
||||
new_seq_list.append(seq_list[i] + [eos_token] * (max_seq_len - len(seq_list[i])))
else:
new_seq_list.append(seq_list[i][:max_seq_len])  # truncate longer sequences so every output has the same length
|
||||
return new_seq_list
|
||||
|
||||
def get_pseudo_labels_per_batch_with_sampling(tokz, model, data, args=None):
|
||||
eos_token = tokz.eos_token_id
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
context = data['context_id'].to('cuda')
|
||||
context_lens = data['context_lens'].to('cuda')
|
||||
out_seqs, out_scores = get_answer(tokz, model, context, context_lens, max_ans_len=args.max_ans_len, args=args)
|
||||
# generate many: out_seq [80, 29], out_scores [80, 10, 50258]
|
||||
|
||||
btz = args.train_batch_size
|
||||
all_seqs = []
|
||||
all_scores = [] # Save batch_size number.
|
||||
all_ans_id = []
|
||||
all_ppl = []
|
||||
all_res = []
|
||||
ans_id, ppl = compute_ppl(out_seqs, out_scores, tokz)
|
||||
|
||||
# output_seq = output_seq.split(args.pl_sample_num)
|
||||
# output_seq = torch.stack(output_seq)
|
||||
# output_scores = output_scores.split(args.pl_sample_num)
|
||||
# output_scores = torch.stack(output_scores)
|
||||
# out_seq [16, 5, 29], out_scores [16, 5, 10, 50258]
|
||||
ppl = torch.stack(torch.split(ppl, args.pl_sample_num)) # [16, 5]
|
||||
print('ppl shape after split', ppl.shape, flush=True)
|
||||
for i in range(btz):
|
||||
example_ppl = ppl[i,:]
|
||||
# scores shape torch.Size([5, 10, 50258]) seq shape torch.Size([5, 10])
|
||||
min_idx = torch.argmin(example_ppl)
|
||||
all_ans_id.append(ans_id[min_idx+i*args.pl_sample_num])
|
||||
all_ppl.append(example_ppl[min_idx])
|
||||
all_ppl = torch.tensor(all_ppl)
|
||||
|
||||
# * Prepare batch for training.
|
||||
spe_token_id = tokz.convert_tokens_to_ids('ANS:')
|
||||
batch_list = []
|
||||
for i in range(btz):
|
||||
info = {}
|
||||
input_id = data['input_id'][i][:data['input_lens'][i]]
|
||||
input_id = input_id.tolist()
|
||||
info['input_id'] = input_id
|
||||
info['ans_id'] = all_ans_id[i] + [tokz.eos_token_id]
|
||||
info['context_id'] = input_id + [spe_token_id]
|
||||
info['all_id'] = info['context_id'] + all_ans_id[i] + [tokz.eos_token_id]
|
||||
batch_list.append(info)
|
||||
|
||||
# Print some pseudo samples.
|
||||
if i==0:
|
||||
print('-'*10, 'Unlabeled sample with pseudo labels:', flush=True)
|
||||
print('input:', tokz.decode(info['input_id']), flush=True)
|
||||
print('all:',tokz.decode(info['all_id']), flush=True)
|
||||
print('-'*20,flush=True)
|
||||
# print(batch_list[0],flush=True)
|
||||
padbatch = PadBatchSeq()
|
||||
batch = padbatch(batch_list)
|
||||
print('ppl', all_ppl.tolist(), flush=True)
return all_ppl, batch
|
||||
|
||||
|
||||
if __name__ =='__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--tasks',nargs='+')
|
||||
parser.add_argument('--model_aug',default=False, type=bool)
|
||||
parser.add_argument('--freeze_plm', default=False, type=bool)
|
||||
args = parser.parse_args()
|
||||
model_path = 'outputs/tc_semi_meantc_K100_ag17/ema_model_last.pt'
|
||||
tokz_path = 'outputs/tc_semi_meantc_K100_ag17/model_cache/out'
|
||||
t5adapter = 'outputs/tc_semi_meantc_K100_ag17/t5adapter'
|
||||
tokz = T5Tokenizer.from_pretrained(tokz_path)
|
||||
config = T5Config()
|
||||
model = SSLLModel(config)
|
||||
model = model.initialize(t5adapter, 'ag', args)
|
||||
model = model.fresh('ag', args, ema=False)
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
output_file = 'outputs/tc_semi_meantc_K100_ag17/pseudo_ag.json'
|
||||
_ = gen_pseudo_data(model, 'ag', tokz, max_output_len=90, batch_size=30, target_count=10, output_file=output_file, args=None)
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from eda import *
|
||||
import torch.nn as nn
|
||||
from unifymodel.dataset import *
|
||||
|
||||
class Curr_Unlabel_Memory(object):
|
||||
def __init__(self,keys=None,values=None, questions=None,args=None):
|
||||
if keys is None:
|
||||
self.keys = []
|
||||
self.values = []
|
||||
self.questions = []
|
||||
else:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
self.questions = questions
|
||||
|
||||
self.total_keys = len(self.keys)
|
||||
if args.unlabel_ratio == -1: # fix unlabel to 2k
|
||||
num_unlabel = 2000
|
||||
else:
|
||||
num_unlabel = args.num_label * args.unlabel_ratio
|
||||
self.memory_size = int(args.newmm_size * num_unlabel)
|
||||
|
||||
def push(self, keys, values, questions, args=None):
|
||||
# print(f'Lengths: key {len(keys)}, values: {len(values)}, questions: {len(questions)}', flush=True)
|
||||
for key, value, ques in zip(keys, values, questions):
|
||||
if self.total_keys < self.memory_size:
|
||||
assert type(self.keys) == list
|
||||
# print(f'key {key}, keys {self.keys}', flush=True)
|
||||
if key.tolist() not in self.keys:
|
||||
self.keys.append(key.tolist())
|
||||
self.values.append(value)
|
||||
self.questions.append(ques)
|
||||
else:
|
||||
drop_idx = random.choice(range(self.memory_size))
|
||||
self.keys.pop(drop_idx)
|
||||
self.values.pop(drop_idx)
|
||||
self.questions.pop(drop_idx)
|
||||
self.keys.append(key.tolist())
|
||||
self.values.append(value)
|
||||
self.questions.append(ques)
|
||||
self.total_keys = len(self.keys) # Update the memory size.
|
||||
|
||||
def create_batch_to_augment_memory(old_task, old_memory, curr_memory, tokz=None, args=None):
|
||||
|
||||
querys = torch.tensor(old_memory['keys'])
|
||||
questions = old_memory['questions']
|
||||
neighbors = get_neighbors(querys, task_name=None, K=args.back_kneighbors, memory=curr_memory, args=args, questions=questions)
|
||||
aug_unlabel_batch = create_batch_from_memory(neighbors, tokz, args=args, task_name=old_task)
|
||||
return aug_unlabel_batch
|
||||
|
||||
|
||||
def cosine_similarity(v1, m2):
|
||||
# print(v1.shape, m2.shape)
|
||||
if len(m2.shape) == 1 and len(v1.shape) == 1:
|
||||
cos = nn.CosineSimilarity(dim=0)
|
||||
elif len(m2.shape)>1:
|
||||
v1 = v1.unsqueeze(0)
|
||||
cos = nn.CosineSimilarity(dim=1)
|
||||
else:
|
||||
print(f'v1 shape {v1.shape}, m2 shape {m2.shape}',flush=True)
|
||||
score = cos(v1, m2)
|
||||
return score
|
||||
|
||||
def get_sentence_embedding(model, batch, args=None):
|
||||
|
||||
model.set_active_adapters(None)
|
||||
batch_size = batch['raw_id'].size()[0]
|
||||
input_tokens = batch['raw_id'].cuda()
|
||||
attn_masks = batch['raw_mask'].cuda()
|
||||
outputs = model.transformer.encoder(input_ids=input_tokens, attention_mask=attn_masks, return_dict=True)
|
||||
pooled_emb = outputs.last_hidden_state
|
||||
sen_emb = torch.mean(pooled_emb, dim=1) # (batch_size, 768)
|
||||
# print('sentence emb shape', sen_emb.shape, flush=True)
|
||||
# if args.diff_question:
|
||||
all_questions = []
|
||||
all_values = []
|
||||
all_keys = []
|
||||
for i in range(batch_size):
|
||||
sample_ques = batch['batch_text'][i]['question']
|
||||
sample_value = batch['batch_text'][i]['raw_text']
|
||||
all_questions.append(sample_ques)
|
||||
all_values.append(sample_value)
|
||||
# all_keys.append(sen_emb[i].tolist())
|
||||
all_keys.append(sen_emb[i,:])
|
||||
# return sen_emb, all_values, all_questions
|
||||
return all_keys, all_values, all_questions
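# Keys are mean-pooled encoder hidden states (one vector per input, sized to the backbone hidden
# dimension, i.e. 768 for t5-base); values are the raw input texts; questions are the per-sample question strings.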
|
||||
|
||||
def construct_memory(task_name, model, batch, memory=None, args=None):
|
||||
|
||||
sen_embs, all_values, all_questions = get_sentence_embedding(model, batch, args=args)
|
||||
|
||||
if memory is None:
|
||||
memory = {}
|
||||
if task_name not in memory:
|
||||
memory[task_name] = {}
|
||||
memory[task_name]['keys'] = []
|
||||
memory[task_name]['values'] = []
|
||||
memory[task_name]['questions'] = []
|
||||
|
||||
batch_size = len(batch['batch_text'])
|
||||
for i in range(batch_size):
|
||||
key = sen_embs[i].tolist()
|
||||
# value = batch['raw_id'][i] #
|
||||
value = batch['batch_text'][i]['raw_text'] # * Store the raw sentence as the value into the memory.
|
||||
memory[task_name]['keys'].append(key)
|
||||
memory[task_name]['values'].append(value)
|
||||
memory[task_name]['questions'].append(all_questions[i])
|
||||
|
||||
return memory
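# Resulting layout, e.g. for task 'ag' (illustrative):
# memory['ag'] = {'keys': [[...pooled floats...], ...], 'values': ['raw sentence', ...], 'questions': ['...', ...]}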
|
||||
|
||||
def get_neighbors(querys, task_name=None, K=1, memory=None, args=None, questions=None):
|
||||
assert memory is not None
|
||||
|
||||
if task_name is not None: #* forward augment, memory is old unlabel memory
|
||||
keys = memory[task_name]['keys']
|
||||
values = memory[task_name]['values']
|
||||
# questions = memory[task_name]['questions'] # forward augment needs current unlabel questions
|
||||
|
||||
else: #* backward augment, memory is current_unlabel_memory
|
||||
keys = memory.keys
|
||||
values = memory.values
|
||||
# print(f'Number of the memory content is {len(keys)}',flush=True)
|
||||
assert questions is not None # should be old memory question
|
||||
# if args.diff_question and questions is None:
|
||||
# questions = memory.questions
|
||||
# questions = memory[task_name]['questions']
|
||||
|
||||
keys_tensor = torch.tensor(keys).cuda()
|
||||
|
||||
retrieved_samples = []
|
||||
for q in range(len(querys)):
|
||||
query = querys[q]
|
||||
query = query.cuda()
|
||||
# print('query shape', query.shape, 'key shape', keys_tensor.shape, flush=True)
|
||||
# similarity_scores = torch.matmul(keys_tensor, query.T)
|
||||
similarity_scores = cosine_similarity(query, keys_tensor)
|
||||
top_k = torch.topk(similarity_scores, K)
|
||||
top_k_idxs = top_k.indices.tolist()
|
||||
if top_k.values[-1] <= args.similarity_tau:
|
||||
continue
|
||||
# print('Top K largest consine similarity scores:', top_k.values.tolist(), flush=True)
|
||||
K_neighbor_keys = [keys[i] for i in top_k_idxs]
|
||||
neighbors = [values[i] for i in top_k_idxs]
|
||||
# retrieved_samples.append(neighbors)
|
||||
retrieved_samples.append((neighbors, questions[q])) # neighbors is a list of unlabeled inputs, questions[q] is a sentence.
|
||||
return retrieved_samples
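# Returns a list of (neighbor_texts, question) pairs: for each query whose top-K similarities all exceed
# args.similarity_tau, the K most cosine-similar stored values paired with that query's question.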
|
||||
|
||||
def create_batch_from_memory(samples, tokz, args, task_name):
|
||||
input_sen, context_sen, all_sen = [], [], []
|
||||
task_prefix = task_name+':'
|
||||
raw_sen, ques_sen = [], []
|
||||
|
||||
if args.diff_question:
|
||||
for neighbors, question in samples:
|
||||
for sen in neighbors:
|
||||
# * Augments retrieved neighbors
|
||||
if args.aug_neighbors_from_memory:
|
||||
aug_text_list = eda(sen, num_aug=args.num_aug)
|
||||
for aug_text in aug_text_list:
|
||||
input_text = task_prefix + aug_text + '<QUES>' + question
|
||||
context_sen.append(input_text+'<ANS>')
|
||||
all_sen.append(aug_text + '<QUES>')
|
||||
raw_sen.append(aug_text)
|
||||
ques_sen.append(question)
|
||||
# all_sen.append(aug_text+'<ANS>'+ans_text)
|
||||
# ans_sen.append(ans_text)
|
||||
# * Directly use retrieved neighbors
|
||||
else:
|
||||
input_text = task_prefix + sen + '<QUES>' + question
|
||||
context_text = input_text+'<ANS>'
|
||||
all_text = sen + '<QUES>'
|
||||
input_sen.append(input_text)
|
||||
context_sen.append(context_text)
|
||||
all_sen.append(all_text)
|
||||
raw_sen.append(sen)
|
||||
ques_sen.append(question)
|
||||
|
||||
else:
|
||||
for sen in samples:
|
||||
# * Augments retrieved neighbors
|
||||
if args.aug_neighbors_from_memory:
|
||||
aug_text_list = eda(sen, num_aug=args.num_aug)
|
||||
for aug_text in aug_text_list:
|
||||
input_text = task_prefix + aug_text + '<QUES>' + question
|
||||
context_sen.append(input_text+'<ANS>')
|
||||
all_sen.append(aug_text + '<QUES>')
|
||||
raw_sen.append(aug_text)
|
||||
ques_sen.append(question)
|
||||
# all_sen.append(aug_text+'<ANS>'+ans_text)
|
||||
# ans_sen.append(ans_text)
|
||||
# * Directly use retrieved neighbors
|
||||
else:
|
||||
input_text = task_prefix + sen + '<QUES>' + question
|
||||
context_text = input_text+'<ANS>'
|
||||
all_text = sen + '<QUES>'
|
||||
input_sen.append(input_text)
|
||||
context_sen.append(context_text)
|
||||
all_sen.append(all_text)
|
||||
raw_sen.append(sen)
|
||||
ques_sen.append(question)
|
||||
|
||||
batch = []
|
||||
batch_list = []
|
||||
if len(input_sen)>=1:
|
||||
for l in range(len(input_sen)):
|
||||
sample_dict = {}
|
||||
sample_dict['all_sen'] = all_sen[l]
|
||||
sample_dict['input_sen'] = input_sen[l]
|
||||
sample_dict['context_sen'] = context_sen[l]
|
||||
sample_dict['question'] = ques_sen[l]
|
||||
sample_dict['raw_text'] = raw_sen[l]
|
||||
batch.append(sample_dict)
|
||||
|
||||
# input_encoding = tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
# all_encoding = tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
# context_encoding = tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
# raw_text_encoding = tokz(raw_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
# question_encoding = tokz(ques_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
|
||||
|
||||
res = {}
|
||||
# res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
|
||||
# res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
|
||||
# res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
|
||||
# res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
|
||||
# res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
|
||||
res['batch_text'] = batch
|
||||
else:
|
||||
res = None
|
||||
|
||||
return res
|
||||
|
||||
def get_old_center_dict(old_memory, prev_tasks, args=None):
|
||||
center_dict = {}
|
||||
for prev_task in prev_tasks:
|
||||
old_keys = old_memory[prev_task]['keys'] # list of list
|
||||
old_keys_tensor = torch.tensor(old_keys).cuda()
|
||||
old_center = torch.mean(old_keys_tensor,dim=0)
|
||||
center_dict[prev_task] = old_center
|
||||
return center_dict
|
||||
|
||||
def adapter_selection(prev_tasks, curr_center, old_center_dict, args=None):
|
||||
distance_list = []
|
||||
for prev_task in prev_tasks:
|
||||
dist = cosine_similarity(curr_center, old_center_dict[prev_task])
|
||||
distance_list.append(dist)
|
||||
print(f'Distances among centers {distance_list}',flush=True)
|
||||
max_dist = max(distance_list)
|
||||
max_idx = distance_list.index(max_dist)
|
||||
adapter_name = prev_tasks[max_idx]
|
||||
return adapter_name
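# Note: despite the 'distance' naming, the scores above are cosine similarities, so the previous task
# whose memory centroid is most similar to the current task's centroid is the one whose adapter is reused.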
|
|
@ -0,0 +1,41 @@
|
|||
import torch
|
||||
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration, PreTrainedModel, T5Model
|
||||
from transformers.adapters import T5AdapterModel
|
||||
from torch import nn
|
||||
|
||||
class SSLLModel(T5AdapterModel):
|
||||
# def __init__(self, config, args=None):
|
||||
# super(SSLLModel, self).__init__(config)
|
||||
# self.args = args
|
||||
# self.model = T5AdapterModel(config)
|
||||
# self.model.transformer is T5Model
|
||||
def initialize(self, model_path, task_name, args=None):
|
||||
self = self.from_pretrained('t5-base', cache_dir=model_path)
|
||||
self.config.vocab_size = self.config.vocab_size + len(args.tasks) + 2 # Add task-specific special tokens. 2 for <ANS> and <QUES>
|
||||
self.transformer.resize_token_embeddings(self.config.vocab_size)
|
||||
# print('old config',self.config,flush=True)
|
||||
# print('decoder_start_token_id',self.config.decoder_start_token_id, flush=True)
|
||||
if args.model_aug:
|
||||
self.config.dropout_rate = args.dropout_rate
|
||||
return self
|
||||
|
||||
def fresh(self, task_name, args=None, ema=False, initial=True):
|
||||
# * Add task-specific adapters.
|
||||
# self.add_adapter('adapter'+task_name)
|
||||
# self.add_seq2seq_lm_head('lm_head'+task_name)
|
||||
# self.train_adapter('adapter'+task_name)
|
||||
task_name = task_name.replace('.','_')
|
||||
if initial:
|
||||
self.add_adapter(task_name)
|
||||
self.add_seq2seq_lm_head(task_name)
|
||||
self.set_active_adapters(task_name)
|
||||
self.train_adapter(task_name) # Freeze other model params.
|
||||
|
||||
if not args.freeze_plm:
|
||||
self.freeze_model(False) # Finetuning all parameters. Default is True.
|
||||
# print('new config',self.config,flush=True)
|
||||
# print('decoder_start_token_id',self.config.decoder_start_token_id, flush=True)
|
||||
if ema:
|
||||
for param in self.parameters():
|
||||
param.detach_()
|
||||
return self
|
|
@ -0,0 +1,940 @@
|
|||
from unifymodel.utils import *
|
||||
from unifymodel.generate import *
|
||||
from unifymodel.model import SSLLModel
|
||||
from unifymodel.dataset import PadBatchSeq, TASK2INFO, LBDataset, get_datasets, get_unlabel_data, get_unlabel_dict, MixedDataset
|
||||
from unifymodel.dataset import *
|
||||
from unifymodel.memory import *
|
||||
|
||||
from transformers import T5Config, T5Tokenizer, T5Model,T5ForConditionalGeneration, T5AdapterModel
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.distributed as dist
|
||||
import os, time, gc, json, pickle, argparse, math, threading, shutil
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data as data
|
||||
from torch.nn import DataParallel
|
||||
import torch.multiprocessing as mp
|
||||
import numpy as np
|
||||
import transformers
|
||||
from transformers import get_linear_schedule_with_warmup, Conv1D, AdamW
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import importlib
|
||||
import copy
|
||||
from collections import Counter
|
||||
from rouge import Rouge
|
||||
from metrics import *
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
AutoModelForMaskedLM,
|
||||
FlaxT5ForConditionalGeneration,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
T5Config,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
T5ForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
is_torch_tpu_available,
|
||||
)
|
||||
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
||||
from pretrain import *
|
||||
|
||||
def compute_solver_loss(model, loss_fn, total, args=None, kl_loss=None):
|
||||
input_tokens, input_mask, target_tokens, target_mask = total
|
||||
model = model.cuda()
|
||||
model.train()
|
||||
outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
|
||||
loss = outputs[0]
|
||||
return loss
|
||||
|
||||
def compute_lm_loss(model, loss_fn, total, args=None, kl_loss=None):
|
||||
input_tokens, decoder_input_ids, labels = total
|
||||
model = model.cuda().train()
|
||||
decoder_input_ids = None
|
||||
outputs = model(input_ids=input_tokens, decoder_input_ids=decoder_input_ids, labels=labels)
|
||||
loss = outputs[0]
|
||||
return loss
|
||||
|
||||
def compute_feedback_loss(model, ema_model, loss_fn, total, args=None, kl_loss=None):
|
||||
input_tokens, input_mask, target_tokens, target_mask = total
|
||||
model, ema_model = model.cuda(), ema_model.cuda()
|
||||
model.train()
|
||||
ema_model.train()
|
||||
|
||||
# outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
|
||||
# logits = outputs.logits.contiguous()
|
||||
# logits = logits.view(-1, logits.size(-1))
|
||||
|
||||
ema_outputs = ema_model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
|
||||
# ema_logits = ema_outputs.logits.contiguous().detach()
|
||||
# ema_logits = ema_logits.view(-1, ema_logits.size(-1))
|
||||
# target_mask = (target_mask>0)
|
||||
# feedback_loss = kl_loss(ema_logits, logits, label_mask=target_mask) # Student model give feedback to teacher model.
|
||||
feedback_loss = ema_outputs.loss
|
||||
print('feedback_loss',feedback_loss.item(), flush=True)
|
||||
return feedback_loss
|
||||
|
||||
def compute_consistency_loss(model, ema_model, loss_fn, total, args=None, kl_loss=None, tokz=None, epoch=0, task_name=None):
    input_tokens, input_mask, target_tokens, target_mask = total
    model, ema_model = model.cuda(), ema_model.cuda()
    model.train()
    ema_model.train()
    btz = input_tokens.size()[0]
    assert task_name is not None
    max_ans_len = max_ans_length_dict[task_name]

    # * Teacher prediction, e.g. ema_seqs [192, 20], target [192, 4], input [192, 119].
    ema_seqs, ema_scores = get_answer(tokz, ema_model, input_tokens, input_mask, max_ans_len, args=args, is_eval=False)

    # Use the teacher outputs as pseudo-labels; mask pad positions with -100.
    ema_target_tokens = ema_seqs[:, 1:]
    ema_target_tokens[ema_target_tokens == tokz.pad_token_id] = -100
    ema_target_tokens = ema_target_tokens.contiguous()
    ema_target_mask = torch.ones_like(ema_target_tokens)
    ema_target_mask[ema_target_tokens == -100] = 0
    ema_target_mask = (ema_target_mask > 0)

    if args.add_confidence_selection:
        # Keep, per example, the sampled sequence with the lowest perplexity.
        _, ppl = compute_ppl(ema_seqs[:, 1:], ema_scores, tokz)
        ppl = torch.stack(torch.split(ppl, args.pl_sample_num))  # [batch_size, pl_sample_num]
        all_ppl = []
        for i in range(btz):
            example_ppl = ppl[i, :]
            min_idx = torch.argmin(example_ppl)
            all_ppl.append(example_ppl[min_idx])
        ppl_batch = torch.tensor(all_ppl)

        # * Select high-confidence outputs: examples whose perplexity exceeds the
        # threshold are masked out entirely.
        greater_thresh = torch.le(ppl_batch, args.pseudo_tau).to('cuda')  # [num_aug * batch_size]
        ema_target_tokens = (ema_target_tokens.T * greater_thresh).T
        ema_target_tokens[ema_target_tokens == tokz.pad_token_id] = -100

    # * Student prediction on the same inputs against the teacher pseudo-labels.
    outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=ema_target_tokens)
    logits = outputs.logits.contiguous()
    logits = logits.view(-1, logits.size(-1))
    consist_loss = outputs.loss

    consistency_weight = get_current_consistency_weight(args, epoch)
    if args.rdrop:
        # R-Drop: a second forward pass with different dropout, regularized by a
        # symmetric KL between the two sets of logits.
        outputs2 = model(input_ids=input_tokens, attention_mask=input_mask, labels=ema_target_tokens)
        logits2 = outputs2.logits.contiguous()
        logits2 = logits2.view(-1, logits2.size(-1))
        rdrop_loss = 0.5 * kl_loss(logits, logits2, label_mask=ema_target_mask) + 0.5 * kl_loss(logits2, logits, label_mask=ema_target_mask)
    else:
        rdrop_loss = 0.0

    consistency_loss = consistency_weight * (consist_loss + rdrop_loss)
    print('consistency_loss:', consistency_loss.item(), flush=True)
    return consistency_loss

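
# Illustrative sketch (not called by the training pipeline): a self-contained toy
# walk-through of the confidence-selection step above. It assumes pad_token_id == 0
# (true for T5) and uses a hypothetical perplexity threshold of 2.0 in place of
# args.pseudo_tau.
def _demo_confidence_selection():
    ema_target_tokens = torch.tensor([[5, 7, 0], [9, 3, 0]])  # teacher pseudo-labels, 0 = pad
    ema_target_tokens[ema_target_tokens == 0] = -100          # pads are ignored by the loss
    ppl_batch = torch.tensor([1.2, 3.5])                      # per-example teacher perplexity
    keep = torch.le(ppl_batch, 2.0)                           # [True, False]
    masked = (ema_target_tokens.T * keep).T                   # zero out the low-confidence row
    masked[masked == 0] = -100                                # then map it to the ignore index
    return masked  # row 0 stays [5, 7, -100]; row 1 is all -100 and adds no gradient
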
def get_model_inputs(data, global_step=10, tokz=None, model_type='cls', task=None, data_type='label'):
    if model_type == 'cls':
        input_tokens = data['context_id']
        input_mask = data['context_mask'].contiguous()

        if global_step < 3:
            print('-' * 10, 'Solver inputs and targets', '-' * 10, flush=True)
            print('input:', tokz.decode(input_tokens[0]), flush=True)
            print('input id:', input_tokens[0], flush=True)
            print('attention_mask:', input_mask[0], flush=True)

        if data_type == 'label':
            target_tokens = data['ans_id']
            target_mask = data['ans_mask'].contiguous()

            if global_step < 3:
                print('target:', tokz.decode(target_tokens[0]), flush=True)
                print('target id:', target_tokens[0], flush=True)
                print('\n', flush=True)

            target_tokens[target_tokens == tokz.pad_token_id] = -100
            target_tokens = target_tokens.contiguous()
            input_total = (input_tokens.cuda(), input_mask.cuda(), target_tokens.cuda(), target_mask.cuda())
        else:
            input_total = (input_tokens.cuda(), input_mask.cuda(), None, None)
        return input_total

    elif model_type == 'lm':
        batch_size = data['input_id'].shape[0]
        assert task is not None
        task_prefix_id = tokz.encode(task + ':')[0]
        input_tokens = torch.full((batch_size, 1), task_prefix_id)
        decoder_input_ids = data['all_id'][:, :-1].contiguous()
        labels = data['all_id'][:, :].clone().detach()

        if global_step < 3:
            print('-' * 10, 'LM inputs and targets', '-' * 10, flush=True)
            print('input:', tokz.decode(input_tokens[0]), flush=True)
            print('input id:', input_tokens[0], flush=True)
            print('target:', tokz.decode(labels[0]), flush=True)
            print('target id:', labels[0], flush=True)
            print('\n', flush=True)

        labels[labels == tokz.pad_token_id] = -100
        input_total = (input_tokens.cuda(), decoder_input_ids.cuda(), labels.cuda())
        return input_total

    elif model_type == 'pretrain':
        input_ids = data['input_ids'].cuda()
        labels = data['labels'].cuda()

        if global_step <= 3:
            print('pretrain_input:', tokz.decode(input_ids[0]), flush=True)
            print('pretrain_label:', tokz.decode(labels[0]), flush=True)

        labels[labels == tokz.pad_token_id] = -100
        labels = labels.contiguous()
        decoder_input_ids = data['decoder_input_ids'].cuda()
        return input_ids, labels, decoder_input_ids

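
# Illustrative sketch (not called by the pipeline): why the pad ids above are
# replaced with -100. Both torch's cross entropy and Hugging Face seq2seq models
# treat -100 as an ignore index, so padded target positions add nothing to the loss.
def _demo_ignore_index():
    import torch.nn.functional as F
    logits = torch.randn(1, 3, 10)         # [batch, seq_len, vocab]
    labels = torch.tensor([[4, 2, -100]])  # the last position is padding
    return F.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=-100)
    # The loss is averaged over the two real positions only.
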
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data = ema_param.data.cuda()
        param.data = param.data.cuda()
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

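
# Illustrative sketch (not called by the pipeline; needs a GPU because the helper
# moves parameters to CUDA): one EMA update moves the teacher a small step toward
# the student. With alpha = 0.9 at a late step, a teacher weight of 0.0 tracking a
# student weight of 1.0 becomes 0.9 * 0.0 + 0.1 * 1.0 = 0.1.
def _demo_ema_update():
    student = nn.Linear(1, 1, bias=False)
    teacher = nn.Linear(1, 1, bias=False)
    nn.init.constant_(student.weight, 1.0)
    nn.init.constant_(teacher.weight, 0.0)
    update_ema_variables(student, teacher, alpha=0.9, global_step=1000)
    return teacher.weight.item()  # ~0.1
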

class Trainer:
    def __init__(self, args, tokz, datasets, logger, cache_dir, memory):
        self.args = args
        self.datasets = datasets
        self.logger = logger
        self.rank = self.args.local_rank
        self.device = self.args.device
        self.global_step = 0
        self.task_step = 0
        distributed = False if self.rank == -1 else True
        self.task_dict = {k: v for v, k in enumerate(self.args.tasks)}

        # Tensorboard logs; only write them on the master node.
        if self.rank in [0, -1]:
            self.train_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'train'), flush_secs=10)
            self.valid_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'valid'))

        self.config = T5Config()
        self.tokz = tokz
        print('len tokz of t5', len(self.tokz), flush=True)

        self.cache_dir = cache_dir
        self.save_folder = self.args.output_dir
        os.makedirs(self.save_folder, exist_ok=True)

        # Randomness
        np.random.seed(args.seed)
        prng = np.random.RandomState()
        torch.random.manual_seed(args.seed)
        gpu = not self.args.no_gpu
        if gpu:
            torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

        self.loss_fn = nn.CrossEntropyLoss(reduction='none').to(self.device, non_blocking=True)
        self.kl_loss = KDLoss(KD_term=self.args.KD_term, T=self.args.KD_temperature)

        # * Load datasets.
        self.data_loaders = {}
        for task in self.datasets:
            self.data_loaders[task] = {}
            train_sampler = torch.utils.data.RandomSampler(datasets[task]['label_train'])
            valid_sampler = None
            self.data_loaders[task]['label_train'] = DataLoader(
                datasets[task]['label_train'], batch_size=self.args.train_batch_size,
                sampler=train_sampler, num_workers=self.args.num_workers, pin_memory=True,
                collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=task))
            self.data_loaders[task]['val'] = DataLoader(
                datasets[task]['val'], batch_size=self.args.eval_batch_size,
                sampler=valid_sampler, num_workers=self.args.num_workers, pin_memory=True,
                collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=task))
            self.data_loaders[task]['test'] = DataLoader(
                datasets[task]['test'], batch_size=self.args.eval_batch_size,
                sampler=None, num_workers=self.args.num_workers, pin_memory=True,
                collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))

    def _train(self, prev_model, prev_ema_model, curr_task, prev_tasks, all_tasks_res, question_dict):
        if len(prev_tasks) != 0:
            last_adapter_name = prev_tasks[-1].replace('.', '_')
        curr_adapter_name = curr_task.replace('.', '_')
        if prev_model is None:
            assert len(prev_tasks) == 0
            t5dir = os.path.join(self.args.output_dir, 't5basedir')
            t5adapter = os.path.join(self.args.output_dir, 't5adapter')
            model = SSLLModel(self.config)
            model = model.initialize(t5adapter, curr_task, self.args)
            model = model.fresh(curr_adapter_name, self.args, ema=False, initial=True)
            model.train()
            if self.args.meantc:
                ema_model = SSLLModel(self.config)
                ema_model = ema_model.initialize(t5adapter, curr_task, self.args)
                ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=True)
                ema_model = ema_model.cuda()
                ema_model.train()
                # The teacher is never updated by the optimizer, only via EMA.
                ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
                assert ema_num_tunable_params == 0
            self.logger.info("Successfully initialized the model with the pretrained T5 model!")
            self.logger.info(f'Tokenizer additional special token ids {self.tokz.additional_special_tokens_ids}')
        else:
            if self.args.random_initialization:
                model = prev_model.fresh(curr_adapter_name, self.args, ema=False, initial=True)
                model.train()
                if self.args.meantc:
                    ema_model = prev_ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=True)
                    ema_model = ema_model.cuda()
                    ema_model.train()
                    ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
                    assert ema_num_tunable_params == 0
            else:
                # Initialize the new task's adapter from the last task's teacher adapter.
                model = prev_model
                model.load_state_dict(prev_ema_model.state_dict())
                model.load_adapter(os.path.join(self.save_folder, last_adapter_name + '_adapter'), load_as=curr_adapter_name)
                model = model.fresh(curr_adapter_name, self.args, ema=False, initial=False)
                model.train()
                ema_model = copy.deepcopy(prev_ema_model)
                ema_model.load_state_dict(prev_ema_model.state_dict())
                ema_model.load_adapter(os.path.join(self.save_folder, last_adapter_name + '_adapter'), load_as=curr_adapter_name)
                ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=False)
                ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
                assert ema_num_tunable_params == 0

        if self.args.backward_augment:
            curr_unlabel_memory = Curr_Unlabel_Memory(args=self.args)

        model = model.to(self.device, non_blocking=True)
        print('model.config', model.config, flush=True)

        # ! Generate pseudo data of previous tasks from the previous teacher model.
        if self.args.gen_replay and len(prev_tasks) >= 1:
            prev_train_dict = {}
            if self.args.use_unlabel:
                pseudo_data_count = int((len(self.datasets[curr_task]['label_train']) + len(self.datasets[curr_task]['unlabel_train'])) * self.args.pseudo_data_ratio) // len(prev_tasks)
            else:
                pseudo_data_count = int(len(self.datasets[curr_task]['label_train']) * self.args.pseudo_data_ratio) // len(prev_tasks)
            self.logger.info(f'Will generate {pseudo_data_count} pseudo samples for each previous task.')

            inferred_data = {}
            if self.args.construct_memory:
                memory = {}
            else:
                memory = None
            for prev_task in prev_tasks:
                pseudo_output_file = os.path.join(self.args.output_dir, curr_task + '_pseudo_' + prev_task + '.json')
                self.logger.info('Generated pseudo data will be written into ' + str(pseudo_output_file))
                self.logger.info(f'Inferring pseudo data from {prev_task}')
                prev_data = gen_pseudo_data(prev_ema_model, prev_task, tokz=self.tokz, max_output_len=max_input_length_dict[prev_task],
                                            batch_size=4, target_count=pseudo_data_count,
                                            output_file=pseudo_output_file, args=self.args, question=question_dict[prev_task])
                inferred_data[prev_task] = prev_data
                for i in range(0, min(6, pseudo_data_count)):
                    self.logger.info(f'PSEUDO: {prev_data[i]}')
                prev_dataset = PseudoDataset(prev_task, prev_data, tokz=self.tokz, max_input_len=self.args.max_input_len)
                prev_dataloader = DataLoader(prev_dataset, batch_size=self.args.train_batch_size, sampler=torch.utils.data.RandomSampler(prev_dataset),
                                             num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=prev_task))
                prev_train_dict[prev_task] = prev_dataloader

                # * Construct the memory for the pseudo data.
                if self.args.construct_memory:
                    for idx, prev_batch in enumerate(prev_dataloader):
                        memory = construct_memory(prev_task, model, prev_batch, memory, args=self.args)
            self.logger.info(f'Finished generating pseudo data for {len(prev_tasks)} previous tasks.')

        # ! Zero-shot evaluation of the new task with the previously learned model.
        if len(prev_tasks) >= 1 and self.args.evaluate_zero_shot:
            self.logger.info(f'Evaluate the new task <<{curr_task}>> zero-shot with the previously learned model.')
            if self.args.select_adapter and self.args.construct_memory:
                test_keys_list = []
                for _ in range(3):
                    curr_test_loader = self.data_loaders[curr_task]['test']
                    curr_test_batch = next(iter(curr_test_loader))
                    test_keys, _, _ = get_sentence_embedding(model, curr_test_batch, args=self.args)  # keys: list of tensors
                    test_keys_list += test_keys
                curr_center = torch.mean(torch.stack(test_keys_list), dim=0)
                old_center_dict = get_old_center_dict(memory, prev_tasks, args=self.args)
                initial_adapter_name = adapter_selection(prev_tasks, curr_center, old_center_dict, args=self.args)
                eval_adapter_name = initial_adapter_name.replace('.', '_')
                self.logger.info(f'We select the old adapter {initial_adapter_name} to initialize the current task {curr_task}!')
            else:
                eval_adapter_name = curr_adapter_name
            initial_score = self._eval_score(model, curr_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None, last=True, use_adapter_name=eval_adapter_name)
            all_tasks_res.append({curr_task + '_initial': {'score': float(initial_score['score'])}})

            if eval_adapter_name != curr_adapter_name:
                # Re-initialize both the student and the teacher from the selected old adapter.
                model.load_adapter(os.path.join(self.save_folder, eval_adapter_name + '_adapter'), load_as=curr_adapter_name)
                model = model.fresh(curr_adapter_name, self.args, ema=False, initial=False)
                model.train()
                ema_model.load_adapter(os.path.join(self.save_folder, eval_adapter_name + '_adapter'), load_as=curr_adapter_name)
                ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=False)
                ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
                assert ema_num_tunable_params == 0
                self.logger.info('Re-initialized the model from the selected adapter.')

self.logger.info("Begin training!"+ "=" * 40)
|
||||
self.logger.info(f'Currently training {curr_task}'+'-'*10)
|
||||
|
||||
evalout_dir = os.path.join(self.args.output_dir, curr_task)
|
||||
if os.path.exists(evalout_dir) and os.path.isdir(evalout_dir):
|
||||
shutil.rmtree(evalout_dir)
|
||||
os.makedirs(evalout_dir, exist_ok=True)
|
||||
|
||||
# Optimizer weight decay
|
||||
param_optimizer = list(model.named_parameters())
|
||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': self.args.weight_decay},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
num_tunable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print('Total number of tunale parameters is %d'%num_tunable_params, flush=True)
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
|
||||
# optimizer = Adafactor(params=filter(lambda p: p.requires_grad, model.parameters()), lr=self.args.lr, weight_decay=self.args.weight_decay,
|
||||
# scale_parameter=False, relative_step=False)
|
||||
optimizer.zero_grad()
|
||||
if self.args.backward_augment:
|
||||
aug_optimizer =AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
|
||||
aug_optimizer.zero_grad()
|
||||
if self.args.noisy_stu or self.args.online_noisy_stu:
|
||||
stu_optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
|
||||
stu_optimizer.zero_grad()
|
||||
|
||||
# label_train_dataset = self.datasets[curr_task]['val']
|
||||
label_train_dataset = self.datasets[curr_task]['label_train']
|
||||
label_train_sampler = torch.utils.data.RandomSampler(label_train_dataset) # not distributed
|
||||
if self.args.use_unlabel:
|
||||
unlabel_train_dataset = self.datasets[curr_task]['unlabel_train']
|
||||
unlabel_train_sampler = torch.utils.data.RandomSampler(unlabel_train_dataset) # not distributed
|
||||
unlabel_batch_size = int(self.args.unlabel_amount) * self.args.train_batch_size
|
||||
unlabel_train_loader = DataLoader(unlabel_train_dataset, batch_size=unlabel_batch_size, sampler=unlabel_train_sampler,
|
||||
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=curr_task))
|
||||
if self.args.add_pretrain_with_ft or self.args.pretrain_first:
|
||||
mix_train_file = os.path.join(self.args.data_dir, curr_task, 'pretrain.txt')
|
||||
if not os.path.exists(mix_train_file):
|
||||
_ = write_mix_train_file(label_train_dataset, unlabel_train_dataset, mix_train_file, os.path.join(self.args.data_dir, curr_task))
|
||||
pretrain_dataloader = create_dataloader_for_pretrain(mix_train_file, self.tokz, model, self.args)
|
||||
|
||||
# * Get train loader.
|
||||
label_train_loader = DataLoader(label_train_dataset, batch_size=self.args.train_batch_size, sampler=label_train_sampler,
|
||||
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=curr_task))
|
||||
|
||||
t_total = len(label_train_loader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs[curr_task]
|
||||
if self.args.warmup_steps > 0:
|
||||
num_warmup_steps = self.args.warmup_steps
|
||||
else:
|
||||
num_warmup_steps = int(t_total * self.args.warmup_proportion)
|
||||
if not self.args.nouse_scheduler: # if True, not change lr.
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, t_total)
|
||||
|
||||
if self.args.pretrain_first and self.args.use_unlabel:
|
||||
self.logger.info('BEGIN Pretraining!'+'*'*10)
|
||||
for pre_epo in range(self.args.pretrain_epoch):
|
||||
PRE_ITER = enumerate(pretrain_dataloader)
|
||||
for k, predata in PRE_ITER:
|
||||
optimizer.zero_grad()
|
||||
pretrain_input_ids, pretrain_labels, pretrain_decoder_input_ids = get_model_inputs(predata, tokz=self.tokz, model_type='pretrain')
|
||||
pretrain_outputs = model(input_ids=pretrain_input_ids, labels=pretrain_labels, decoder_input_ids=pretrain_decoder_input_ids)
|
||||
pretrain_loss = pretrain_outputs.loss
|
||||
# print('pretrain loss',pretrain_outputs.loss.item(), flush=True)
|
||||
pretrain_loss.backward()
|
||||
for n, p in model.named_parameters():
|
||||
if curr_task in n:
|
||||
p.grad.data.zero_()
|
||||
optimizer.step()
|
||||
self.logger.info(f'Pretrain epoch: {pre_epo}, ' +' Pretrain loss: %.4f' % (pretrain_loss.item()))
|
||||
self.logger.info('Finish Pretraining!'+'*'*10)
|
||||
|
||||
        # Training =====================================================================
        # ! Mixed training on labeled and unlabeled data.
        self.task_step = 0
        for epoch in range(1, self.args.num_train_epochs[curr_task] + 1):
            save_steps = int(t_total // self.args.num_train_epochs[curr_task] * self.args.save_epochs)

            if self.args.eval_times_per_task > 0:
                eval_steps = int(t_total / self.args.eval_times_per_task)
            else:
                eval_steps = self.args.eval_steps

            if self.rank in [0, -1]:
                self.logger.info(
                    f'num_warmup_steps: {num_warmup_steps}, t_total: {t_total}, save_steps: {save_steps}, eval_steps: {eval_steps}')
            self.logger.info("Total epochs: %d" % self.args.num_train_epochs[curr_task])
            self.logger.info('*' * 50 + "Epoch: %d" % epoch + '*' * 50)
            st = time.time()

            LITER = enumerate(label_train_loader)

            # * Training on labeled data.
            for i, data in LITER:
                if i == 0 and epoch == 1:
                    question = data['batch_text'][0]['question']
                    print(f'For <<{curr_task}>>, the question is: "{question}"', flush=True)
                    question_dict[curr_task] = question

                if self.args.freeze_plm:
                    curr_adapter_name = curr_task.replace('.', '_')
                    model.train_adapter(curr_adapter_name)
                    if self.args.meantc:
                        ema_model.set_active_adapters(curr_adapter_name)
                        ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
                        assert ema_num_tunable_params == 0
                else:  # tune all parameters.
                    curr_adapter_name = curr_task.replace('.', '_')
                    model.set_active_adapters(curr_adapter_name)
                    model.freeze_model(False)
                    if self.args.meantc:
                        ema_model.set_active_adapters(curr_adapter_name)
                        ema_model.freeze_model(False)

                # * Get the inputs of the solver and (optionally) the LM objective.
                cls_total = get_model_inputs(data, self.global_step, self.tokz)
                label_cls_loss = compute_solver_loss(model, self.loss_fn, cls_total, args=self.args)
                if self.args.add_label_lm_loss:
                    lm_total = get_model_inputs(data, self.global_step, self.tokz, model_type='lm', task=curr_task)
                    label_lm_loss = compute_lm_loss(model, self.loss_fn, lm_total, args=self.args)
                else:
                    label_lm_loss = 0.0
                all_label_loss = label_cls_loss + self.args.lm_lambda * label_lm_loss
                all_label_loss.backward()

                if self.args.add_pretrain_with_ft and self.args.use_unlabel:
                    pretrain_data = next(iter(pretrain_dataloader))
                    pretrain_input_ids, pretrain_labels, pretrain_decoder_input_ids = get_model_inputs(pretrain_data, self.global_step, self.tokz, model_type='pretrain')
                    pretrain_outputs = model(input_ids=pretrain_input_ids, labels=pretrain_labels, decoder_input_ids=pretrain_decoder_input_ids)
                    print('pretrain loss', pretrain_outputs.loss.item(), flush=True)
                    pretrain_loss = 0.1 * pretrain_outputs.loss
                    pretrain_loss.backward()
                    # * Only update backbone parameters that are shared across tasks.
                    for n, p in model.named_parameters():
                        if curr_task in n:
                            p.grad.data.zero_()
                else:
                    pretrain_loss = 0.0

                if not self.args.debug_use_unlabel:
                    if self.args.stu_feedback and self.args.meantc:
                        # Gate unlabeled training on the teacher's feedback loss.
                        fb_loss = 0.1 * compute_feedback_loss(model, ema_model, self.loss_fn, cls_total, args=self.args, kl_loss=self.kl_loss)
                        if abs(fb_loss - label_cls_loss) <= self.args.feedback_threshold:
                            use_unlabel = True
                        else:
                            use_unlabel = False
                    elif epoch >= self.args.warmup_epoch:
                        use_unlabel = True
                    else:
                        use_unlabel = False
                else:
                    use_unlabel = True  # Just to test the memory-burst situation.

                if self.args.use_unlabel:
                    if self.args.unlabel_amount == 'all':
                        unlabel_iter = len(unlabel_train_loader) // len(label_train_loader)
                    else:
                        unlabel_iter = int(self.args.unlabel_amount)  # batches of unlabeled data per labeled batch.

                    for j in range(unlabel_iter):
                        unlabel_data = next(iter(unlabel_train_loader))

                        if self.args.add_unlabel_lm_loss:
                            unlabel_lm_total = get_model_inputs(unlabel_data, self.global_step, self.tokz, model_type='lm', task=curr_task)
                            unlabel_lm_loss = (1 / unlabel_iter) * self.args.lm_lambda * compute_lm_loss(model, self.loss_fn, unlabel_lm_total, args=self.args)
                            unlabel_lm_loss.backward()
                        else:
                            unlabel_lm_loss = 0.0

                        # * Use the current task's unlabeled data to augment the memory for old tasks.
                        if self.args.backward_augment:
                            curr_keys, curr_values, curr_questions = get_sentence_embedding(model, unlabel_data, args=self.args)
                            curr_unlabel_memory.push(curr_keys, curr_values, curr_questions)
                            model.train_adapter(curr_adapter_name)

                        # ! Use the unlabeled data for consistency regularization.
                        if use_unlabel:
                            if self.args.input_aug:
                                unlabeled_batch_text = unlabel_data['batch_text']
                                aug_unlabel_data = aug_unlabeled_data(unlabeled_batch_text, self.tokz, self.args, curr_task)
                                un_cls_total = get_model_inputs(aug_unlabel_data, self.global_step, self.tokz, data_type='unlabeled')
                            else:
                                un_cls_total = get_model_inputs(unlabel_data, self.global_step, self.tokz, data_type='unlabeled')
                            consistency_loss = (1 / unlabel_iter) * self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, un_cls_total, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=epoch, task_name=curr_task)
                            consistency_loss.backward()

                            # ! Augment the current unlabeled data with pseudo unlabeled data retrieved from the memory.
                            if self.args.construct_memory and len(prev_tasks) >= 1 and self.args.forward_augment:
                                assert memory is not None
                                all_tasks_neighbors = []
                                for prev_task in prev_tasks:
                                    querys, _, query_questions = get_sentence_embedding(model, unlabel_data, args=self.args)  # queries from the current task
                                    # Get neighbors together with the corresponding questions.
                                    neighbors = get_neighbors(querys, prev_task, self.args.kneighbors, memory=memory, args=self.args, questions=query_questions)
                                    model.train_adapter(curr_adapter_name)
                                    all_tasks_neighbors += neighbors

                                if len(all_tasks_neighbors) > 0:
                                    print(f'The constructed memory contains "{len(all_tasks_neighbors)}" neighbors to augment.', flush=True)
                                    pseudo_unlabel_batch = create_batch_from_memory(all_tasks_neighbors, self.tokz, self.args, curr_task)
                                    if pseudo_unlabel_batch is not None:
                                        pseudo_dataset = MemoryDataset(pseudo_unlabel_batch)
                                        pseudo_sampler = torch.utils.data.RandomSampler(pseudo_dataset)
                                        pseudo_loader = DataLoader(pseudo_dataset, batch_size=self.args.train_batch_size, sampler=pseudo_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))
                                        for _, pseudo_batch in enumerate(pseudo_loader):
                                            pseudo_un_cls_total = get_model_inputs(pseudo_batch, self.global_step, self.tokz, data_type='unlabeled')
                                            consistency_loss_from_memory = (1 / unlabel_iter) * self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, pseudo_un_cls_total, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=epoch, task_name=curr_task)
                                            consistency_loss_from_memory.backward()
                                            print(f'forward memory loss: {consistency_loss_from_memory.item()}')

                if (i + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    if not self.args.nouse_scheduler:
                        scheduler.step()  # updates the lr.
                    if self.args.meantc:
                        update_ema_variables(model, ema_model, self.args.ema_decay, self.task_step)
                    optimizer.zero_grad()
                    self.task_step += 1

                # Read back the current learning rate.
                lr = scheduler.get_last_lr()[0]

                # Write to the logger.
                self.logger.info(f'EPOCH: {epoch}, task step: {self.task_step}, global step: {self.global_step} ' + 'Training loss: %.4f' % (label_cls_loss) + ' LM loss: %.4f' % (label_lm_loss))

                # Log to Tensorboard.
                self.train_writer.add_scalar('train_label_cls_loss', label_cls_loss, self.global_step)
                self.train_writer.add_scalar('train_label_lm_loss', label_lm_loss, self.global_step)
                self.train_writer.add_scalar('time_per_batch', time.time() - st, self.global_step)
                self.train_writer.add_scalar('lr', lr, self.global_step)
                st = time.time()

                self.global_step += 1

                # Evaluate the model on the validation data of all learned tasks. =====
                if self.global_step % self.args.eval_steps == 0 and self.args.eval_steps > 0:
                    all_res_per_eval = {}
                    avg_metric_loss, avg_metric = [], []
                    tc_avg_metric = []
                    model.eval()
                    for val_task in self.args.tasks[:len(prev_tasks) + 1]:
                        eval_loss = self._eval(model, val_task, self.loss_fn)
                        eval_metric = self._eval_score(model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None)
                        if self.args.meantc:
                            tc_eval_metric = self._eval_score(ema_model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task + '_teacher') if self.args.log_eval_res else None)

                        avg_metric_loss.append(float(eval_loss['loss']))
                        avg_metric.append(float(eval_metric['score']))
                        if self.args.meantc:
                            tc_avg_metric.append(float(tc_eval_metric['score']))

                        for key, value in eval_loss.items():
                            self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)
                        for key, value in eval_metric.items():
                            self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)
                        if self.args.meantc:
                            for key, value in tc_eval_metric.items():
                                self.valid_writer.add_scalar(f'{val_task}/tc_{key}', value, self.global_step)
                            self.logger.info('Evaluation! ' + f'Current task: {curr_task}, task_epoch: {epoch}, task step: {self.task_step}, ' + f'global step: {self.global_step}, {val_task}: teacher eval metric: {tc_eval_metric}')

                        self.logger.info('Evaluation! ' + f'Current task: {curr_task}, task_epoch: {epoch}, task step: {self.task_step}, ' + f'global step: {self.global_step}, {val_task}: eval metric: {eval_metric}')

                        all_res_per_eval[val_task] = {'score': eval_metric['score']}
                    all_tasks_res.append(all_res_per_eval)
                    self.valid_writer.add_scalar('avg/score', sum(avg_metric) / len(avg_metric), self.global_step)
                    if self.args.meantc:
                        self.valid_writer.add_scalar('avg/tc_score', sum(tc_avg_metric) / len(tc_avg_metric), self.global_step)
                    model.train()
            self.logger.info("Training loop: epoch %d completed." % epoch)

        if self.args.backward_augment and len(prev_tasks) > 0:
            # Fine-tune the old adapters on retrieved unlabeled data at a small fixed lr.
            aug_optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
            for g in aug_optimizer.param_groups:
                g['lr'] = 5e-6
            aug_optimizer.zero_grad()
            self.logger.info('Begin backward augmentation training!')
            for old_task in self.args.tasks[:len(prev_tasks)]:
                old_adapter_name = old_task.replace('.', '_')
                model.set_active_adapters(old_adapter_name)
                ema_model.set_active_adapters(old_adapter_name)
                model.train_adapter(old_adapter_name)

                old_memory = memory[old_task]
                aug_old_unlabel_batch = create_batch_to_augment_memory(old_task, old_memory, curr_unlabel_memory, tokz=self.tokz, args=self.args)
                if aug_old_unlabel_batch is not None:
                    aug_dataset = MemoryDataset(aug_old_unlabel_batch)
                    aug_sampler = torch.utils.data.RandomSampler(aug_dataset)  # not distributed
                    aug_loader = DataLoader(aug_dataset, batch_size=self.args.train_batch_size, sampler=aug_sampler,
                                            num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))
                    for _, aug_batch in enumerate(aug_loader):
                        aug_old_untotal = get_model_inputs(aug_batch, self.global_step, self.tokz, data_type='unlabeled')
                        backaug_consistency_loss = self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, aug_old_untotal, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=i, task_name=old_task)
                        backaug_consistency_loss.backward()
                        print(f'backward augment loss {backaug_consistency_loss.item()}')
                        aug_optimizer.step()
                        aug_optimizer.zero_grad()
                        update_ema_variables(model, ema_model, self.args.ema_decay, 1e5)

        # * Test all learned tasks after finishing the current task's training.
        if self.args.test_all:
            last_avg_metric = []
            last_tc_avg_metric = []
            curr_task_score = {}
            prev_tasks_scores = []
            for val_task in self.args.tasks[:len(prev_tasks) + 1]:
                last_eval_metric = self._eval_score(model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None, last=True)
                last_tc_eval_metric = self._eval_score(ema_model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task + '_teacher') if self.args.log_eval_res else None, last=True)
                last_avg_metric.append(float(last_eval_metric['score']))
                last_tc_avg_metric.append(float(last_tc_eval_metric['score']))
                if val_task == curr_task:
                    curr_task_score['score'] = float(last_eval_metric['score'])
                    curr_task_score['tc_score'] = float(last_tc_eval_metric['score'])
                else:
                    prev_tasks_scores.append({val_task: {'score': float(last_eval_metric['score']), 'tc_score': float(last_tc_eval_metric['score'])}})
            all_tasks_res.append({curr_task + '_finish': curr_task_score})
            # Average scores of the previously learned tasks, excluding the current one.
            if len(prev_tasks) >= 1:
                avg_old_stu_score = sum(last_avg_metric[:-1]) / len(last_avg_metric[:-1])
                avg_old_tc_score = sum(last_tc_avg_metric[:-1]) / len(last_tc_avg_metric[:-1])
                average_old_score = {'score': avg_old_stu_score, 'tc_score': avg_old_tc_score}
                all_tasks_res.append({'prev_tasks_scores': prev_tasks_scores})
                all_tasks_res.append({'Average_wo_curr_task': average_old_score})
            # Average scores of all learned tasks.
            avg_stu_score = sum(last_avg_metric) / len(last_avg_metric)
            avg_tc_score = sum(last_tc_avg_metric) / len(last_tc_avg_metric)
            self.valid_writer.add_scalar('avg_alldata/score', avg_stu_score, self.global_step)
            self.logger.info(f'Last evaluation! Average score: {avg_stu_score}; average teacher score: {avg_tc_score}' + '*' * 10)
            average_score = {'score': avg_stu_score, 'tc_score': avg_tc_score}
            all_tasks_res.append({'Average': average_score})

        # Model saving. ------------------------------------------------------------------------
        task_final_save_path = os.path.join(self.save_folder, str(curr_task) + '_model' + '.pt')
        self.logger.info('Saving model checkpoint of task %s to %s' % (curr_task, task_final_save_path))
        torch.save(model.state_dict(), task_final_save_path)

        model_save_path = os.path.join(self.save_folder, 'model_last.pt')
        torch.save(model.state_dict(), model_save_path)
        if self.args.meantc:
            ema_task_save_path = os.path.join(self.save_folder, str(curr_task) + '_ema_model' + '.pt')
            ema_all_save_path = os.path.join(self.save_folder, 'ema_model_last.pt')
            torch.save(ema_model.state_dict(), ema_task_save_path)
            torch.save(ema_model.state_dict(), ema_all_save_path)
            ema_model.save_adapter(os.path.join(self.save_folder, curr_adapter_name + '_adapter'), curr_adapter_name)  # save the teacher adapter.

        self.logger.info('Saving total model checkpoint to %s', model_save_path)
        self.logger.info('=' * 25 + 'Training complete!' + '=' * 25)

        return model, ema_model, all_tasks_res, question_dict

    def _eval(self, model, task, loss_fn):
        lm_recorder = MetricsRecorder(self.device, 'loss')

        with torch.no_grad():
            adapter_name = task.replace('.', '_')
            model.set_active_adapters(adapter_name)
            model.eval()
            if self.args.test_overfit:
                test_dataset = self.data_loaders[task]['label_train']
            else:
                test_dataset = self.data_loaders[task]['test']

            for i, data in enumerate(test_dataset):
                label_total = get_model_inputs(data, self.global_step, self.tokz)
                loss = compute_solver_loss(model, self.loss_fn, label_total, args=self.args)
                lm_recorder.metric_update(i, {'loss': loss})
                if i >= 35:  # cap validation at 36 batches for speed.
                    break

        if self.rank != -1:
            lm_recorder.all_reduce()
        return lm_recorder

    # * Generate answers and compute the task metric on the test set.
    def _eval_score(self, model, task, out_dir=None, last=False, use_adapter_name=None):
        assert task is not None
        # The maximum answer length is a fixed per-task cap.
        max_ans_len = max_ans_length_dict[task]
        if out_dir is not None:
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            out_dir = os.path.join(out_dir, f'{task}_step{self.global_step}.json')
            out_dir_content = []

        with torch.no_grad():
            if use_adapter_name:
                model.set_active_adapters(use_adapter_name)
            else:
                adapter_name = task.replace('.', '_')
                model.set_active_adapters(adapter_name)
            model.eval()
            pred_ans_all, gold_ans_all = [], []
            context_all = []
            all_pred_task_name = []
            if task in ['wikisql', 'woz.en', 'squad']:
                all_gold_list = []

            if self.args.test_overfit:
                test_dataloader = self.data_loaders[task]['label_train']
            else:
                test_dataloader = self.data_loaders[task]['test']

            for i, data in enumerate(test_dataloader):
                example = data['context_id'].cuda()
                example_mask = data['context_mask'].cuda()
                if out_dir is not None:
                    context_all.append(example)
                gold_ans = data['ans_id'].cuda()

                pred_ans, _ = get_answer(self.tokz, model, example, example_mask, max_ans_len, args=self.args, is_eval=True)
                pred_ans = pred_ans[:, 1:]  # Remove the first <pad> token.

                pred_ans_all.append(pred_ans)
                gold_ans_all.append(gold_ans)

                if task in ['wikisql', 'woz.en', 'squad']:
                    for k in range(len(data['context_id'])):
                        gold_ans_list = data['batch_text'][k]['ans_list']
                        all_gold_list.append(gold_ans_list)

                # During training, evaluate at most 36 batches for speed; the final
                # evaluation (last=True) runs over the full test set.
                if not last and task not in ['wikisql', 'woz.en', 'squad']:
                    if i >= 35:
                        break

        # Collect the results from the different nodes.
        pred_ans_all = communicate_tensor(pred_ans_all, pad_token=self.tokz.pad_token_id).tolist()
        gold_ans_all = communicate_tensor(gold_ans_all, pad_token=self.tokz.pad_token_id).tolist()
        if out_dir is not None:
            context_all = communicate_tensor(context_all, pad_token=self.tokz.pad_token_id).tolist()

        correct_count, total_count = 0, len(pred_ans_all)
        qa_results = []
        pred_list = []
        print(f'Total number of evaluation examples is {total_count}', flush=True)

        for i in range(total_count):
            res = {}
            if out_dir is not None:
                context = self.tokz.decode(strip_list(context_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
                res['context'] = context
            pred = self.tokz.decode(cut_eos(pred_ans_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
            res['ans_pred'] = pred
            pred_list.append(pred)

            # * Get the ground truth.
            if task == 'woz.en':
                gold = self.datasets[task]['test'].answers[i]
                res['ans_gold'] = gold[1]
            elif task == 'wikisql':
                gold = self.datasets[task]['test'].answers[i]
                res['ans_gold'] = gold['answer']
            elif task == 'squad':
                gold = all_gold_list[i]
                assert type(gold) == list
                qa_results.append([pred, gold])
                res['ans_gold'] = gold[0]
            else:
                gold = self.tokz.decode(cut_eos(gold_ans_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
                qa_results.append([pred, [gold]])
                res['ans_gold'] = gold
            if out_dir is not None:
                out_dir_content.append(res)

        if task in ['wikisql', 'woz.en']:
            # These datasets are scored against the original example order.
            test_dataset = self.datasets[task]['test']
            ids = test_dataset.get_indices()
            test_dataset.sort_by_index()
            qa_results = [x[1] for x in sorted([(i, g) for i, g in zip(ids, pred_list)])]

            for i in range(len(test_dataset)):
                Y = test_dataset.answers[i]
                qa_results[i] = [qa_results[i], Y]

        score_dict = compute_metrics(qa_results,
                                     bleu='iwslt.en.de' in task or 'multinli.in.out' in task,
                                     dialogue='woz.en' in task,
                                     rouge='cnn_dailymail' in task,
                                     logical_form='wikisql' in task,
                                     corpus_f1='zre' in task)
        self.logger.info(f'Score dictionary for task {task} >> {score_dict}')
        metric = task2metric_dict[task]
        score = score_dict[metric]

        if out_dir is not None:
            with open(out_dir, 'w', encoding='utf-8') as outfile:
                for res in out_dir_content:
                    print(json.dumps(res), file=outfile)
        model.train()
        return {'score': '{:.6f}'.format(score)}

    # * Train on all tasks sequentially through time.
    def train(self, tasks):
        model = None
        ema_model = None
        self.logger.info('Total number of tasks is ' + str(len(tasks)))

        all_tasks_res = []
        res_file = os.path.join(self.save_folder, 'metrics.json')
        question_dict = {}

        for i in range(len(tasks)):
            model, ema_model, all_tasks_res, question_dict = self._train(model, ema_model, curr_task=tasks[i], prev_tasks=tasks[:i], all_tasks_res=all_tasks_res, question_dict=question_dict)

            self.logger.info('We have trained %d tasks' % (i + 1))
            if self.args.use_memory:
                self.logger.info('Memory has saved information of %d tasks' % (len(memory.memory.keys())) + '-' * 10)

            # * Save the evaluation metrics.
            with open(res_file, 'w', encoding='utf-8') as f:
                for e in all_tasks_res:
                    print(json.dumps(e, ensure_ascii=False), file=f)

@ -0,0 +1,358 @@
from unifymodel.dataset import PadBatchSeq, pad_seq, get_unlabel_data
import logging
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
import json
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import os, time, gc, pickle, argparse, math, re
import torch.nn as nn
import torch.utils.data as data
import torch.distributed as dist
import torch.multiprocessing as mp
import copy
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F
from tools.eda import *

loggers = {}

def get_logger(filename, level=logging.INFO, print2screen=True):
    global loggers

    if os.path.exists(filename):
        os.remove(filename)

    if loggers.get(filename):
        return loggers.get(filename)
    else:
        logger = logging.getLogger(filename)
        logger.setLevel(level)
        fh = logging.FileHandler(filename, encoding='utf-8')
        fh.setLevel(level)
        ch = logging.StreamHandler()
        ch.setLevel(level)
        formatter = logging.Formatter('[%(asctime)s][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        logger.addHandler(fh)
        if print2screen:
            logger.addHandler(ch)
        loggers[filename] = logger
        return logger

def frange_cycle_linear(n_iter, start=0.01, stop=1.0, n_cycle=4, ratio=0.5):
    L = np.ones(n_iter) * stop
    period = n_iter / n_cycle
    step = (stop - start) / (period * ratio)  # linear schedule within each cycle

    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and (int(i + c * period) < n_iter):
            L[int(i + c * period)] = v
            v += step
            i += 1
    return L

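
# Illustrative sketch: frange_cycle_linear yields a cyclic linear ramp (e.g. for
# KL annealing). With 8 iterations, 2 cycles, and ratio 0.5, each cycle ramps up
# over its first half and sits at `stop` for the rest.
def _demo_frange_cycle_linear():
    return frange_cycle_linear(8, start=0.0, stop=1.0, n_cycle=2, ratio=0.5)
    # -> [0.0, 0.5, 1.0, 1.0, 0.0, 0.5, 1.0, 1.0]
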
def num_params(model):
    return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])

def switch_schedule(schedule, mult, switch):
    """Apply an LR multiplier before iteration `switch`."""
    def f(e):
        s = schedule(e)
        if e < switch:
            return s * mult
        return s

    return f

def linear_schedule(args):
    def f(e):
        if e <= args.warmup:
            return e / args.warmup
        return max((e - args.iterations) / (args.warmup - args.iterations), 0)

    return f

def get_mask(x_len):
    mask = torch.arange(max(x_len), device=x_len.device)[None, :] < x_len[:, None]  # [bs, max_len]
    return mask.bool()

def get_reverse_mask(x_len):
    mask = torch.arange(max(x_len) - 1, -1, -1, device=x_len.device)[None, :] < x_len[:, None]  # [bs, max_len]
    return mask.bool()

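
# Illustrative sketch: get_mask turns a batch of lengths into a boolean padding
# mask with True at real positions counted from the left; get_reverse_mask counts
# from the right instead.
def _demo_length_masks():
    x_len = torch.tensor([2, 4])
    return get_mask(x_len), get_reverse_mask(x_len)
    # get_mask         -> [[True, True, False, False], [True, True, True, True]]
    # get_reverse_mask -> [[False, False, True, True], [True, True, True, True]]
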
def compare_tokens(x, y, eos_id):
    if eos_id in x:
        x = x[:x.index(eos_id)]
    if eos_id in y:
        y = y[:y.index(eos_id)]
    return x == y

def pad_tensor(tensor, length, pad_token=0):
    return torch.cat([tensor, tensor.new(tensor.size(0), length - tensor.size()[1]).fill_(pad_token)], dim=1)

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

def communicate_tensor(tensor_list, pad_token=0):
    '''Collect tensors from all processes.'''
    if len(tensor_list) == 0:
        return None
    device = tensor_list[0].device
    max_len = torch.tensor(max([i.shape[1] for i in tensor_list]), dtype=torch.int64, device=device)
    if dist.is_initialized():  # Obtain the max length of the second dim across processes.
        dist.all_reduce(max_len, op=dist.ReduceOp.MAX)
    # Pad the tensors to max_len and stack them into one batch.
    tensor = torch.cat([pad_tensor(i, max_len, pad_token) for i in tensor_list], dim=0)
    tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
    max_tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
    if dist.is_initialized():
        dist.all_reduce(max_tensor_bs, op=dist.ReduceOp.MAX)  # Obtain the max batch size across processes.
        if max_tensor_bs != tensor_bs:
            tensor = torch.cat([tensor, tensor.new(max_tensor_bs - tensor_bs, tensor.shape[1]).fill_(pad_token)], dim=0)

        # Gather the padded tensors and the batch size of each tensor.
        tensor_list = [torch.ones_like(tensor).fill_(pad_token) for _ in range(dist.get_world_size())]
        tensor_bs_list = [torch.ones_like(tensor_bs).fill_(pad_token) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list=tensor_list, tensor=tensor.contiguous())
        dist.all_gather(tensor_list=tensor_bs_list, tensor=tensor_bs)
        # Cut the padding off each gathered batch.
        for i in range(dist.get_world_size()):
            tensor_list[i] = tensor_list[i][:tensor_bs_list[i]]
        tensor = torch.cat(tensor_list, dim=0)
    return tensor

class MetricsRecorder(object):
    def __init__(self, device, *metric_names):
        self.metric_names = list(metric_names)
        self.device = device
        self.metrics = {}
        for metric_name in metric_names:
            self.metrics[metric_name] = torch.tensor(0, dtype=torch.float64, device=self.device)

    def metric_update(self, batch_no, metric_values_dict):
        # Maintain a running mean over batches for each metric.
        for k, v in metric_values_dict.items():
            self.metrics[k] = (self.metrics[k] * batch_no + v) / (batch_no + 1)

    def add_to_writer(self, writer, step, group):
        for n in self.metric_names:
            m = self.metrics[n].item()
            writer.add_scalar('%s/%s' % (group, n), m, step)

    def write_to_logger(self, logger, epoch, step):
        log_str = 'epoch {:>3}, step {}'.format(epoch, step)
        for n in self.metric_names:
            m = self.metrics[n].item()
            log_str += ', %s %g' % (n, m)
        logger.info(log_str)

    def items(self):
        return self.metrics.items()

    def all_reduce(self):
        for n in self.metric_names:
            torch.distributed.all_reduce(self.metrics[n], op=torch.distributed.ReduceOp.SUM)
            self.metrics[n] /= torch.distributed.get_world_size()

    def __getitem__(self, k):
        return self.metrics[k]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def __repr__(self):
        return self.metrics.__repr__()

    def __str__(self):
        return self.metrics.__str__()

    def keys(self):
        return self.metrics.keys()

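
# Illustrative sketch: MetricsRecorder keeps a running mean per metric, folding in
# one batch value per metric_update call. Shown on CPU with plain floats.
def _demo_metrics_recorder():
    rec = MetricsRecorder('cpu', 'loss')
    rec.metric_update(0, {'loss': 2.0})  # running mean: 2.0
    rec.metric_update(1, {'loss': 4.0})  # running mean: (2.0 * 1 + 4.0) / 2 = 3.0
    return rec['loss'].item()
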
def cut_eos(seq, eos_id):
    if eos_id not in seq:
        return seq
    return seq[:seq.index(eos_id)]

def cal_metrics_from_pred_files(res_file):
    with open(res_file, 'r', encoding='utf-8') as f:
        res = [json.loads(i) for i in f.readlines()]
    y_true = [i['ans_gold'] for i in res]
    y_pred = [i['ans_pred'] for i in res]
    return {
        "accuracy": accuracy_score(y_true, y_pred),
    }

def slot_f1_score(pred_slots, true_slots):
    '''
    pred_slots and true_slots look like [['from_location:10-11', 'leaving_date:12-13']].
    '''
    slot_types = set([slot.split(":")[0] for row in true_slots for slot in row])
    slot_type_f1_scores = []

    for slot_type in slot_types:
        predictions_for_slot = [[p for p in prediction if slot_type in p] for prediction in pred_slots]  # e.g. [['from_location'], [], [], ['from_location']]
        labels_for_slot = [[l for l in label if slot_type in l] for label in true_slots]

        proposal_made = [len(p) > 0 for p in predictions_for_slot]
        has_label = [len(l) > 0 for l in labels_for_slot]
        prediction_correct = [prediction == label for prediction, label in zip(predictions_for_slot, labels_for_slot)]
        true_positives = sum([
            int(proposed and correct)
            for proposed, correct in zip(proposal_made, prediction_correct)])

        num_predicted = sum([int(proposed) for proposed in proposal_made])
        num_to_recall = sum([int(hl) for hl in has_label])

        precision = true_positives / (1e-5 + num_predicted)
        recall = true_positives / (1e-5 + num_to_recall)

        f1_score = 2 * precision * recall / (1e-5 + precision + recall)
        slot_type_f1_scores.append(f1_score)

    return np.mean(slot_type_f1_scores)

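
# Illustrative sketch: slot_f1_score averages F1 over slot types. With one slot
# type predicted exactly and the other missed entirely, the mean is about 0.5.
def _demo_slot_f1():
    true_slots = [['from_location:10-11', 'leaving_date:12-13']]
    pred_slots = [['from_location:10-11']]
    return slot_f1_score(pred_slots, true_slots)  # ~0.5
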
def textid_decode(text, eos, tokz, contain_eos=False):
    if eos in text:
        if contain_eos:
            idx = text.index(eos)
            text = text[:idx]
        else:
            text = text[:text.index(eos)]
    text_id = text
    len_text = len(text)
    text = tokz.decode(text).strip()
    return text, len_text, text_id

def padding_convert(text_list, eos):
    tt_list = []
    for text in text_list:
        if eos in text:
            eos_indexs = [i for i, x in enumerate(text) if x == eos]
            if len(eos_indexs) > 1:
                text = text[:eos_indexs[1]]
        tt_list.append(text)
    tt_lens = [len(i) for i in tt_list]
    tt_lens = torch.tensor(tt_lens, dtype=torch.long).to('cuda')
    tt_pad = torch.tensor([pad_seq(i, eos, max(tt_lens), pad_left=True) for i in tt_list], dtype=torch.long).to('cuda')
    return tt_pad, tt_lens

def strip_list(seq, eos_id):
    # Trim the leading and trailing runs of eos_id from a decoded id list.
    l, r = 0, len(seq) - 1
    for i in range(len(seq)):
        if seq[i] != eos_id:
            break
        l = i
    for i in range(len(seq) - 1, -1, -1):
        if seq[i] != eos_id:
            break
        r = i
    return seq[l + 1:r]

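
# Illustrative sketch: cut_eos truncates a decoded id list at the first EOS, while
# strip_list trims the leading and trailing runs of the EOS id (exclusive of the
# boundary indices it lands on).
def _demo_eos_helpers():
    return cut_eos([5, 8, 1, 9], 1), strip_list([1, 1, 5, 8, 2, 1], 1)
    # cut_eos    -> [5, 8]
    # strip_list -> [5, 8, 2]
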
def aug_unlabeled_data(batch, tokz, args, task_name):
    # * Augment each unlabeled sentence with EDA (easy data augmentation).
    input_sen, context_sen, all_sen, ans_sen = [], [], [], []
    task_prefix = task_name + ':'
    raw_sen, ques_sen = [], []

    for i in batch:
        text = i['raw_text']
        ans_text = i['ans_sen']
        question = i['question']
        aug_text_list = eda(text, num_aug=args.num_aug)
        for aug_text in aug_text_list:
            input_text = task_prefix + aug_text + '<QUES>' + question
            input_sen.append(input_text)
            context_sen.append(input_text + '<ANS>')
            all_sen.append(aug_text + '<QUES>')  # train the LM only on the inputs
            ans_sen.append(ans_text)
            raw_sen.append(aug_text)
            ques_sen.append(question)

    input_encoding = tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
    all_encoding = tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
    context_encoding = tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
    ans_encoding = tokz(ans_sen, padding='longest', max_length=args.max_ans_len, truncation=True, return_tensors='pt')
    raw_text_encoding = tokz(raw_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
    question_encoding = tokz(ques_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')

    res = {}
    res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
    res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
    res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
    res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
    res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
    res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
    res['batch_text'] = batch
    return res

def get_current_consistency_weight(args, epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242.
    return args.consistency * sigmoid_rampup(epoch, args.consistency_rampup)

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from https://arxiv.org/abs/1610.02242."""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))

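
# Illustrative sketch: over a 5-epoch ramp the consistency weight factor rises from
# about exp(-5) ~ 0.007 to 1.0, so pseudo-label consistency barely counts at the
# start of training and reaches full strength at the end of the ramp.
def _demo_sigmoid_rampup():
    return [round(sigmoid_rampup(e, 5), 3) for e in range(7)]
    # -> [0.007, 0.041, 0.165, 0.449, 0.819, 1.0, 1.0]
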

class KDLoss(nn.Module):
    def __init__(self, KD_term=0.0, T=1.0):
        super(KDLoss, self).__init__()
        assert 0 <= KD_term <= 5
        assert 0 < T
        self.KD_term = KD_term
        self.T = T

    def forward(self, output_logits, teacher_logits, label_mask=None):
        # KD_loss has shape [batch_size * seq_len, vocab_size] before the reduction.
        KD_loss = F.kl_div(F.log_softmax(output_logits / self.T, dim=1),
                           F.softmax(teacher_logits / self.T, dim=1), reduction='none')
        KD_loss = torch.sum(KD_loss, dim=1)
        if label_mask is not None:
            label_mask = label_mask.view(-1)
            KD_loss = KD_loss.where(label_mask.cuda(), torch.tensor(0.0).cuda())
            kd_loss = KD_loss.sum() / label_mask.sum()
        else:
            kd_loss = KD_loss
        # Multiplying by T^2 keeps the KL loss scale invariant to the temperature.
        return kd_loss * self.KD_term * self.T * self.T

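
# Illustrative sketch (needs a GPU, since forward() moves the mask to CUDA):
# KDLoss is a temperature-scaled KL divergence between student and teacher logits
# of shape [num_tokens, vocab_size], with an optional boolean label_mask of shape
# [num_tokens]. Identical logits give a loss of zero.
def _demo_kd_loss():
    kd = KDLoss(KD_term=1.0, T=2.0)
    logits = torch.randn(4, 10).cuda()
    mask = torch.tensor([True, True, True, False])
    return kd(logits, logits, label_mask=mask)  # -> ~0.0
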

def kl_divergence(mean1, logvar1, mean2, logvar2):
    # Closed-form KL between two diagonal Gaussians.
    exponential = logvar1 - logvar2 - \
        torch.pow(mean1 - mean2, 2) / logvar2.exp() - \
        torch.exp(logvar1 - logvar2) + 1
    result = -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))
    return result.mean()

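
# Illustrative sketch: kl_divergence is the closed-form KL between two diagonal
# Gaussians given their means and log-variances; identical parameters give zero.
def _demo_gaussian_kl():
    mean, logvar = torch.zeros(2, 3), torch.zeros(2, 3)
    return kl_divergence(mean, logvar, mean, logvar)  # -> tensor(0.)
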
@ -0,0 +1,56 @@
from unifymodel.utils import get_logger, seed_everything
from unifymodel.trainer import Trainer
from unifymodel.dataset import get_datasets
import torch
from settings import parse_args
import os
from transformers import AdamW, get_linear_schedule_with_warmup, Conv1D
from transformers import T5Tokenizer, T5Model

os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_args()
logger = get_logger(args.log_file)

if args.local_rank in [0, -1]:
    logger.info('Pytorch Version: {}'.format(torch.__version__))
    for k, v in vars(args).items():
        logger.info("{} = {}".format(k, v))

# Seed everything.
seed_everything(args.seed)

cache_dir = os.path.join(args.output_dir, 'model_cache')
os.makedirs(cache_dir, exist_ok=True)

# * Add special tokens.
special_tokens_dict = {'additional_special_tokens': ['<ANS>', '<QUES>']}
out = os.path.join(cache_dir, 'out')
tokz = T5Tokenizer.from_pretrained('t5-base', cache_dir=out)

num_spe_token = tokz.add_special_tokens(special_tokens_dict)

if args.use_task_pmt:
    pre_token_list = []
    for task in args.tasks:
        pre_token_list += [str(task) + ':']
else:
    pre_token_list = ['TASK:' for i in range(len(args.tasks))]

num_add_tokens = tokz.add_tokens(pre_token_list)
print('We have added', num_spe_token + num_add_tokens, 'tokens to T5', flush=True)
tokz.save_pretrained(out)

if args.local_rank in [0, -1]:
    logger.info('Loading datasets...' + '.' * 10)
datasets = get_datasets(args.data_dir, args.tasks, tokz, max_input_len=args.max_input_len)
logger.info('Finished loading datasets!')

if args.use_memory:
    memory = Memory()
else:
    memory = None

trainer = Trainer(args, tokz, datasets, logger, cache_dir, memory=memory)
trainer.train(args.tasks)