add: ssll

出蛰 2023-05-26 16:08:26 +08:00
parent 1034f23eec
commit fa8479e963
29 changed files with 6444 additions and 0 deletions

@@ -0,0 +1,40 @@
# SSLL
## Introduction
-----
SSLL (Semi-Supervised Lifelong Language Learning) was built by the Conversational AI Team, Alibaba DAMO Academy.
The corresponding paper was published in EMNLP 2022 Findings: "***Semi-Supervised Lifelong Language Learning***".
## SSLL Implementation
-----
### Requirements
```
pip install -r requirements.txt
```
### Dataset
The datasets used in the experiments follow [LAMOL](https://arxiv.org/abs/1909.03329).
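For reference, the intent-detection loaders in `dataset.py` read JSON-lines files (`label_train.json`, `unlabel_train.json`, `valid.json`, `test.json`) whose records carry a `userInput.text` utterance and an `intent` label. A minimal sketch of one record (field names follow `dataset.py`; the values are invented for illustration):
```
import json

record = {"userInput": {"text": "what is the weather tomorrow"}, "intent": "weather_query"}
print(json.dumps(record, ensure_ascii=False))  # one object per line in the *.json files
```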
### Model Training and Evaluation
```
sh scripts/lltrain.sh
```
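The first positional argument of `lltrain.sh` is the GPU id: the script reads it as `gpuid=$1` and passes it through `CUDA_VISIBLE_DEVICES`, so a typical invocation is `sh scripts/lltrain.sh 0`.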
The files required for training are under the folder `unifymodel`.
## Citation
-----
If you use our code or find SSLL useful for your work, please cite our paper as:\
@inproceedings{zhao2022semi,\
title={Semi-Supervised Lifelong Language Learning},\
author={Zhao, Yingxiu and Zheng, Yinhe and Yu, Bowen and Tian, Zhiliang and Lee, Dongkyu and Sun, Jian and Li, Yongbin and Zhang, Nevin L.},\
booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing: Findings},\
year={2022}\
}
LAMOL citation:\
@inproceedings{sun2019lamol,\
title={LAMOL: LAnguage MOdeling for Lifelong Language Learning},\
author={Sun, Fan-Keng and Ho, Cheng-Hao and Lee, Hung-Yi},\
booktitle={International Conference on Learning Representations},\
year={2020}\
}

ssll/backtrans.py
@@ -0,0 +1,60 @@
from transformers import MarianMTModel, MarianTokenizer
import torch
# target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
# tgt_tokz_path='out_BT/tgt_tokz'
# tgt_model_path='out_BT/tgt_model'
# target_tokenizer = MarianTokenizer.from_pretrained(tgt_tokz_path)
# target_model = MarianMTModel.from_pretrained(tgt_model_path)
# # en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
# en_tokz_path='out_BT/en_tokz'
# en_model_path='out_BT/en_model'
# en_tokenizer = MarianTokenizer.from_pretrained(en_tokz_path)
# en_model = MarianMTModel.from_pretrained(en_model_path)
# target_tokenizer.save_pretrained('out_BT/tgt_tokz.pt')
# target_model.save_pretrained('out_BT/tgt_model.pt')
# en_tokenizer.save_pretrained('out_BT/en_tokz.pt')
# en_model.save_pretrained('out_BT/en_model.pt')
# target_model, en_model = target_model.to('cuda'), en_model.to('cuda')
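# Note: multilingual Marian checkpoints such as opus-mt-en-ROMANCE select the
# output language with a ">>lang<<" prefix token, which translate() prepends
# below for any non-English target.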
def translate(texts, model, tokenizer, language="fr"):
# Prepare the text data into appropriate format for the model
def template(
text): return f"{text}" if language == "en" else f">>{language}<< {text}"
src_texts = [template(text) for text in texts]
# Tokenize the texts
encoded = tokenizer.prepare_seq2seq_batch(src_texts)
# input_ids = tokenizer.encode(src_texts)
for k, v in encoded.items():
encoded[k] = torch.tensor(v)
encoded = encoded.to('cuda')
# Generate translation using model
translated = model.generate(**encoded)
# Convert the generated tokens indices back into text
translated_texts = tokenizer.batch_decode(
translated, skip_special_tokens=True)
return translated_texts
def back_translate(texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr"):
# Translate from source to target language
fr_texts = translate(texts, target_model, target_tokenizer,
language=target_lang)
# Translate from target language back to source language
back_translated_texts = translate(fr_texts, en_model, en_tokenizer,
language=source_lang)
return back_translated_texts
# en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
# # en_texts = 'I love you.'
# aug_texts = back_translate(en_texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr")
# print(aug_texts)

ssll/dataset.py
@@ -0,0 +1,348 @@
import torch
import csv
import os
import re
import json
import numpy as np
from settings import parse_args
from eda import *
# from train import special_tokens, special_token_ids
args = parse_args()
class LBIDDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
# self.CLS_PROMPT = CLSPrompt()
# self.SlotTagging_PROMPT = SlotTaggingPrompt()
self.max_ans_len = 0
self.special_token_ids = special_token_ids
with open(data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def get_answer(self, intent): # get answer text from intent
# replace - and _ to make the text seem more natural
ans = intent.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
text = example['userInput']['text']
ans = self.get_answer(example['intent'])
num_aug = 0 # Number of augmented sentences per sample.
if not args.meantc or num_aug == 0:
# print('Run this.', flush=True)
return [(text, ans)]
else:
aug_text_list = eda(text, num_aug=num_aug)
res = [(aug_text, ans) for aug_text in aug_text_list]
res.append((text, ans))
return res
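# per_tokenization() lays out each example as
#   all_id:     <bos> input tokens, ANS token, answer tokens, <eos>
#   context_id: the prefix ending at the ANS token (the generation prompt)
#   ans_id:     the continuation the model is trained to produce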
def per_tokenization(self, task_name, d):
# print(d,flush=True)
input_text, ans_text = d
input_id = self.tokz.encode(input_text)
ans_id = self.tokz.encode(ans_text)
self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
return {
'all_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']] + ans_id + [self.tokz.eos_token_id],
'input_id': [self.tokz.bos_token_id] + input_id,
'context_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
'ans_id': ans_id + [self.tokz.eos_token_id]}
def data_tokenization(self, task_name, data):
# print('data',data[:10],flush=True)
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
class ULBIDDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
self.max_ans_len = 0
self.special_token_ids = special_token_ids
with open(data_path, "r", encoding='utf-8') as f:
data = [json.loads(i) for i in f.readlines()]
data = [self.parse_example(i) for i in data] # [utter, label]
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def get_answer(self, intent): # get answer text from intent
# replace - and _ to make the text seem more natural
ans = intent.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
text = example['userInput']['text']
ans = self.get_answer(example['intent'])
return text, ans
def per_tokenization(self, task_name, d):
input_text, ans_text = d
input_id = self.tokz.encode(input_text)
ans_id = self.tokz.encode(ans_text)
self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
return {
'all_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']] + ans_id + [self.tokz.eos_token_id],
'input_id': [self.tokz.bos_token_id] + input_id,
# 'input_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
'context_id': [self.tokz.bos_token_id] + input_id + [self.special_token_ids['ans_token']],
'ans_id': ans_id + [self.tokz.eos_token_id]}
def data_tokenization(self, task_name, data):
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
class PinnedBatch:
def __init__(self, data):
self.data = data
def __getitem__(self, k):
return self.data[k]
def __setitem__(self, key, value):
self.data[key] = value
def pin_memory(self):
for k in self.data.keys():
self.data[k] = self.data[k].pin_memory()
return self
def __repr__(self):
return self.data.__repr__()
def __str__(self):
return self.data.__str__()
def items(self):
return self.data.items()
def keys(self):
return self.data.keys()
def pad_seq(seq, pad, max_len, pad_left=False):
if pad_left:
return [pad] * (max_len - len(seq)) + seq
else:
return seq + [pad] * (max_len - len(seq))
class PadBatchSeq:
def __init__(self, pad_id=0):
self.pad_id = pad_id
def __call__(self, batch):
input_id = [i['input_id'] for i in batch]
context_id = [i['context_id'] for i in batch]
all_id = [i['all_id'] for i in batch]
ans_id = [i['ans_id'] for i in batch]
input_lens = [len(i) for i in input_id]
context_lens = [len(i) for i in context_id]
all_lens = [len(i) for i in all_id]
ans_lens = [len(i) for i in ans_id]
input_mask = torch.ByteTensor([[1] * input_lens[i] + [0] * (max(input_lens)-input_lens[i]) for i in range(len(input_id))])
all_mask = torch.ByteTensor([[1] * all_lens[i] + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))])
all_label_mask = torch.ByteTensor([[0] * (context_lens[i]) + [1] * (all_lens[i] - context_lens[i]) + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))]) # Only calculate losses on answer tokens.
res = {}
res["input_id"] = torch.tensor([pad_seq(i, self.pad_id, max(input_lens)) for i in input_id], dtype=torch.long)
res["context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(context_lens), pad_left=True) for i in context_id], dtype=torch.long)
res["all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(all_lens)) for i in all_id], dtype=torch.long)
res["ans_id"] = torch.tensor([pad_seq(i, self.pad_id, max(ans_lens)) for i in ans_id], dtype=torch.long)
res["input_lens"] = torch.tensor(input_lens, dtype=torch.long)
res["context_lens"] = torch.tensor(context_lens, dtype=torch.long)
res["all_lens"] = torch.tensor(all_lens, dtype=torch.long)
res["ans_lens"] = torch.tensor(ans_lens, dtype=torch.long)
res['input_mask'] = input_mask
res['all_mask'] = all_mask
res['all_label_mask'] = all_label_mask
return PinnedBatch(res)
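# Usage sketch (assumed wiring; `tokz` here is the GPT-2-style tokenizer whose
# bos/eos ids are used above):
#   loader = torch.utils.data.DataLoader(
#       train_dataset, batch_size=16, shuffle=True,
#       collate_fn=PadBatchSeq(pad_id=tokz.eos_token_id), pin_memory=True)
# Each batch then carries padded all_id/context_id/input_id/ans_id tensors,
# their lengths, and the masks built above; pin_memory=True works because
# PinnedBatch implements pin_memory().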
# * Information for different tasks.
if args.data_type == 'intent':
TASK2INFO = {
"banking": {
"dataset_class": LBIDDataset,
"dataset_folder": "banking",
"task_type": "CLS",
},
"clinc": {
"dataset_class": LBIDDataset,
"dataset_folder": "clinc",
"task_type": "CLS",
},
"hwu": {
"dataset_class": LBIDDataset,
"dataset_folder": "hwu",
"task_type": "CLS",
},
"atis": {
"dataset_class": LBIDDataset,
"dataset_folder": "atis",
"task_type": "CLS",
},
"tod": {
"dataset_class": LBIDDataset,
"dataset_folder": "FB_TOD_SF",
"task_type": "CLS",
},
"snips": {
"dataset_class": LBIDDataset,
"dataset_folder": "snips",
"task_type": "CLS",
},
"top_split1": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split1",
"task_type": "CLS",
},
"top_split2": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split2",
"task_type": "CLS",
},
"top_split3": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split3",
"task_type": "CLS",
}
}
else:
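# NOTE: PromptSlotTaggingDataset is referenced below but is not defined in
# this file; it is assumed to be provided elsewhere in the repository.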
TASK2INFO = {
"dstc8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "dstc8",
"task_type": "SlotTagging",
},
"restaurant8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "restaurant8",
"task_type": "SlotTagging",
},
"atis": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "atis",
"task_type": "SlotTagging",
},
"mit_movie_eng": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_eng",
"task_type": "SlotTagging",
},
"mit_movie_trivia": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_trivia",
"task_type": "SlotTagging",
},
"mit_restaurant": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_restaurant",
"task_type": "SlotTagging",
},
"snips": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "snips",
"task_type": "SlotTagging",
},
}
def get_unlabel_data(path, task):
info = TASK2INFO[task]
data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
return data_path
def get_unlabel_dict(path, task, tokz, args):
info = TASK2INFO[task]
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
if args.data_type == 'intent':
return LBIDDataset(task, tokz, path, args.ctx_max_len, special_token_ids=special_token_ids)
else:
return None
class MixedDataset(LBIDDataset):
def __init__(self, unlabel_data, label_data, tokz, ctx_max_len=100):
self.ctx_max_len = ctx_max_len
self.tokz = tokz
self.max_ans_len = 0
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
self.data = []
self.data += unlabel_data
self.data += label_data
class AugDataset(LBIDDataset):
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.max_ans_len = 0
self.special_token_ids = special_token_ids
with open(label_data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
with open(unlabel_data_path, "r", encoding='utf-8') as f:
ori_data += [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def parse_example(self, example):
text = example['userInput']['text']
aug_text_list = eda(text, num_aug=args.num_aug)
ans = self.get_answer(example['intent'])
res = [(aug_text, ans) for aug_text in aug_text_list]
res.append((text, ans))
return res
def get_datasets(path, tasks, tokz, ctx_max_len=100, special_token_ids=None):
res = {}
for task in tasks:
res[task] = {}
info = TASK2INFO[task]
if args.data_type == "intent":
res[task]['label_train'] = LBIDDataset(
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
res[task]['unlabel_train'] = ULBIDDataset(
task, tokz, os.path.join(path, info['dataset_folder'], 'unlabel_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
if args.meantc:
filename = 'unlabel_' + str(args.num_unlabel) + '_train.json'
res[task]['aug_train'] = AugDataset(task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'),
os.path.join(path, info['dataset_folder'], filename), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'valid.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
return res
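# Typical call (sketch; the data directory value is a placeholder):
#   special_ids = {'ans_token': tokz.convert_tokens_to_ids('ANS:')}
#   datasets = get_datasets('DATA', args.tasks, tokz,
#                           ctx_max_len=args.ctx_max_len, special_token_ids=special_ids)
#   batch_source = datasets['banking']['label_train']  # feed to a DataLoader with PadBatchSeq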

ssll/decanlp_order.txt
@@ -0,0 +1,5 @@
(wikisql sst woz.en squad srl)
(srl squad wikisql woz.en sst)
(wikisql woz.en srl sst squad)
(srl sst squad woz.en wikisql)
(woz.en sst squad wikisql srl)

ssll/divide_data.py
@@ -0,0 +1,42 @@
import json, os, argparse
import random
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
parser.add_argument('--num_label',type=int,help='Number of labeled data for each training dataset.')
args = parser.parse_args()
data_outdir = args.data_outdir
train_dir = args.train_dir
K = args.num_label
train_file = os.path.join(train_dir,'ripe_data','train.json')
train_out_label = os.path.join(data_outdir,'label_train.json')
train_out_unlabel = os.path.join(data_outdir,'unlabel_train.json')
with open(train_file,'r', encoding='utf-8') as f:
data = [json.loads(i) for i in f.readlines()]
label_idx_set = set(random.sample(range(len(data)), K))  # set gives O(1) membership tests
label_data = [data[i] for i in range(len(data)) if i in label_idx_set]
print(K, len(label_data))
unlabel_data = [data[i] for i in range(len(data)) if i not in label_idx_set]
with open(train_out_label, 'w', encoding='utf-8') as f:
for i in label_data:
print(json.dumps(i, ensure_ascii=False), file=f)
with open(train_out_unlabel, 'w', encoding='utf-8') as f:
for i in unlabel_data:
print(json.dumps(i, ensure_ascii=False), file=f)
str_list = ['100','500','2000']
for i in str_list:
random.shuffle(unlabel_data)
with open(os.path.join(data_outdir, 'unlabel_'+i+'_train.json'), 'w', encoding='utf-8') as f:
for j in unlabel_data[:int(i)]:
print(json.dumps(j, ensure_ascii=False), file=f)
print('Finish dividing',args.train_dir)
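# Example invocation (paths are placeholders):
#   python divide_data.py --train_dir DATA/ag --data_outdir DATA_Divided/DATA_100/ag --num_label 100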

ssll/eda.py
@@ -0,0 +1,246 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou
import nltk
# nltk.download('wordnet', '/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
# nltk.download('omw-1.4')
nltk.data.path.append('/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
from nltk.corpus import wordnet  # needed by get_synonyms() below
import re
import random
from random import shuffle
random.seed(1)
#stop words list
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
'ours', 'ourselves', 'you', 'your', 'yours',
'yourself', 'yourselves', 'he', 'him', 'his',
'himself', 'she', 'her', 'hers', 'herself',
'it', 'its', 'itself', 'they', 'them', 'their',
'theirs', 'themselves', 'what', 'which', 'who',
'whom', 'this', 'that', 'these', 'those', 'am',
'is', 'are', 'was', 'were', 'be', 'been', 'being',
'have', 'has', 'had', 'having', 'do', 'does', 'did',
'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
'because', 'as', 'until', 'while', 'of', 'at',
'by', 'for', 'with', 'about', 'against', 'between',
'into', 'through', 'during', 'before', 'after',
'above', 'below', 'to', 'from', 'up', 'down', 'in',
'out', 'on', 'off', 'over', 'under', 'again',
'further', 'then', 'once', 'here', 'there', 'when',
'where', 'why', 'how', 'all', 'any', 'both', 'each',
'few', 'more', 'most', 'other', 'some', 'such', 'no',
'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
'very', 's', 't', 'can', 'will', 'just', 'don',
'should', 'now', '']
#cleaning up text
def get_only_chars(line):
clean_line = ""
line = line.replace("", "")
line = line.replace("'", "")
line = line.replace("-", " ") # replace hyphens with spaces
line = line.replace("\t", " ")
line = line.replace("\n", " ")
line = line.lower()
for char in line:
if char in 'qwertyuiopasdfghjklzxcvbnm ':
clean_line += char
else:
clean_line += ' '
clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces
if clean_line.startswith(' '):  # startswith() is safe when the line is empty
clean_line = clean_line[1:]
return clean_line
########################################################################
# *Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################
#for the first time you use wordnet
# import nltk
# nltk.download('wordnet')
def synonym_replacement(words, n):
new_words = words.copy()
random_word_list = list(
set([word for word in words if word not in stop_words]))
random.shuffle(random_word_list)
num_replaced = 0
for random_word in random_word_list:
synonyms = get_synonyms(random_word)
if len(synonyms) >= 1:
synonym = random.choice(list(synonyms))
new_words = [synonym if word == random_word else word for word in new_words]
#print("replaced", random_word, "with", synonym)
num_replaced += 1
if num_replaced >= n: # only replace up to n words
break
#this is stupid but we need it, trust me
sentence = ' '.join(new_words)
new_words = sentence.split(' ')
return new_words
def get_synonyms(word):
synonyms = set()
for syn in wordnet.synsets(word):
for l in syn.lemmas():
synonym = l.name().replace("_", " ").replace("-", " ").lower()
synonym = "".join(
[char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
synonyms.add(synonym)
if word in synonyms:
synonyms.remove(word)
return list(synonyms)
########################################################################
# *Random deletion
# Randomly delete words from the sentence with probability p
########################################################################
def random_deletion(words, p):
#obviously, if there's only one word, don't delete it
if len(words) == 1:
return words
#randomly delete words with probability p
new_words = []
for word in words:
r = random.uniform(0, 1)
if r > p:
new_words.append(word)
#if you end up deleting all words, just return a random word
if len(new_words) == 0 and len(words)>1:
rand_int = random.randint(0, len(words)-1)
return [words[rand_int]]
return new_words
########################################################################
# *Random swap
# Randomly swap two words in the sentence n times
########################################################################
def random_swap(words, n):
new_words = words.copy()
for _ in range(n):
if len(new_words)>1:
new_words = swap_word(new_words)
else:
new_words = words.copy()
return new_words
def swap_word(new_words):
random_idx_1 = random.randint(0, len(new_words)-1)
random_idx_2 = random_idx_1
counter = 0
while random_idx_2 == random_idx_1:
random_idx_2 = random.randint(0, len(new_words)-1)
counter += 1
if counter > 3:
return new_words
new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
return new_words
########################################################################
# *Random insertion
# Randomly insert n words into the sentence
########################################################################
def random_insertion(words, n):
new_words = words.copy()
for _ in range(n):
add_word(new_words)
return new_words
def add_word(new_words):
synonyms = []
counter = 0
while len(synonyms) < 1:
if len(new_words)>1:
random_word = new_words[random.randint(0, len(new_words)-1)]
synonyms = get_synonyms(random_word)
counter += 1
if counter >= 10:
return
random_synonym = synonyms[0]
random_idx = random.randint(0, len(new_words)-1)
new_words.insert(random_idx, random_synonym)
########################################################################
# main data augmentation function
########################################################################
def eda(sentence, alpha_sr=0.0, alpha_ri=0.0, alpha_rs=0.1, p_rd=0.05, num_aug=5):
sentence = get_only_chars(sentence)
words = sentence.split(' ')
words = [word for word in words if word!='']
num_words = len(words)
augmented_sentences = []
num_new_per_technique = int(num_aug/4)+1
#sr
if (alpha_sr > 0):
n_sr = max(1, int(alpha_sr*num_words))
for _ in range(num_new_per_technique):
a_words = synonym_replacement(words, n_sr)
augmented_sentences.append(' '.join(a_words))
#ri
if (alpha_ri > 0):
n_ri = max(1, int(alpha_ri*num_words))
for _ in range(num_new_per_technique):
a_words = random_insertion(words, n_ri)
augmented_sentences.append(' '.join(a_words))
#rs
if (alpha_rs > 0):
n_rs = max(1, int(alpha_rs*num_words))
for _ in range(num_new_per_technique):
a_words = random_swap(words, n_rs)
augmented_sentences.append(' '.join(a_words))
#rd
if (p_rd > 0):
for _ in range(num_new_per_technique):
a_words = random_deletion(words, p_rd)
augmented_sentences.append(' '.join(a_words))
augmented_sentences = [get_only_chars(sentence)
for sentence in augmented_sentences]
shuffle(augmented_sentences)
#trim so that we have the desired number of augmented sentences
if num_aug >= 1:
augmented_sentences = augmented_sentences[:num_aug]
else:
keep_prob = num_aug / len(augmented_sentences)
augmented_sentences = [
s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
#append the original sentence
augmented_sentences.append(sentence)
shuffle(augmented_sentences)
return augmented_sentences
if __name__ == '__main__':
s = 'I like the cat with white and black fur.'
print(eda(s))

ssll/evaluate.py
@@ -0,0 +1,94 @@
import torch
import csv
import os
import json
import logging
from fp16 import FP16_Module
import GPUtil
from collections import OrderedDict
from settings import args, MODEL_CLASS, TOKENIZER, SPECIAL_TOKEN_IDS, init_logging
from settings import MEMORY_FACTOR, LEN_FACTOR, TASK_DICT, MODEL_CONFIG, DATA_ATTRS, SPECIAL_TOKENS, CONFIG_CLASS, CONFIG_NAME
from utils import QADataset, top_k_top_p_filtering, create_dataloader, logits_to_tokens, get_model_dir
from utils import sample_sequence, remove_id, get_gen_token, lll_unbound_setting
from metrics import *
logger = logging.getLogger(__name__)
def test_one_to_one(curr_task, task_eval, model, score_dict,last=True, eval_data=None, out_dir=None):
logger.info('Start to evaluate %s'%task_eval)
max_ans_len = eval_data.max_ans_len + 1
# print('Maximum answer length',max_ans_len,flush=True)
if out_dir is not None:
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_dir = os.path.join(
out_dir, f'{task_eval}.json')
out_dir_content = []
with torch.no_grad():
model.set_active_adapters(task_eval)
# print('model of task: {}'.format(task),list(model.parameters()), flush=True)
model.eval()
pred_ans_all, gold_ans_all = [], []
context_all = []
all_pred_task_name = []
for i, data in enumerate(eval_data):
example = data['context_id'].cuda()
example_mask = data['context_mask'].cuda()
if out_dir is not None:
context_all.append(example)
gold_ans = data['ans_id'].cuda()
# Print model inputs to check.
# if i==0:
# print('context:', TOKENIZER.decode(example[0]), flush=True)
# print('ans:', TOKENIZER.decode(gold_ans[0]), flush=True)
# This is a module-level function, so use the module-level TOKENIZER and args
# imported from settings rather than the stray self.* references; get_answer is
# assumed to be the generation helper provided by utils.
pred_ans, _ = get_answer(TOKENIZER, model, example, example_mask, args.max_ans_len, args=args, is_eval=True)
# print('context shape', context.shape, flush=True)
# pred_ans = pred_ans[:, context.shape[1]:]
pred_ans = pred_ans[:,1:] # Remove the first <pad> token.
# print('pred',pred_ans,flush=True)
# print('gold',gold_ans,flush=True) # output tensors
pred_ans_all.append(pred_ans)
gold_ans_all.append(gold_ans)
# * Compute score.
# print('task_eval',task_eval,flush=True)
# Assemble qa_results as the (prediction, [gold answers]) text pairs that
# compute_metrics expects (format assumed from metrics.py).
qa_results = []
for pred_batch, gold_batch in zip(pred_ans_all, gold_ans_all):
for pred, gold in zip(pred_batch, gold_batch):
qa_results.append((TOKENIZER.decode(pred.tolist(), skip_special_tokens=True),
[TOKENIZER.decode(gold.tolist(), skip_special_tokens=True)]))
get_test_score(task_eval, qa_results, score_dict)
return model, score_dict
def get_test_score(task_eval,qa_results,score_dict):
# score = ours_compute_metrics(task_eval, qa_results)
# score = compute_metrics(task_eval, qa_results)
score = compute_metrics(
qa_results,
bleu='iwslt.en.de' in task_eval or 'multinli.in.out' in task_eval,
dialogue='woz.en' in task_eval,
rouge='cnn_dailymail' in task_eval,
logical_form='wikisql' in task_eval,
corpus_f1='zre' in task_eval
)
score_dict[task_eval] = score
def training_test_one_to_many(model,curr_task,step,last=True, dataloaders=None, args=None): # task_load like 'hwu' 'banking'
model.eval()
score_dicts = []
logger.info("task: {}, step: {}".format(curr_task, step))
score_dict = {k:None for k in args.tasks}
with torch.no_grad():
for task_eval in args.tasks:
if args.test_overfit:
eval_data = dataloaders[task_eval]['label_train']
else:
eval_data = dataloaders[task_eval]['test']
eval_dir = os.path.join(args.output_dir, curr_task)
test_one_to_one(curr_task, task_eval, model, score_dict, last=last, eval_data=eval_data, out_dir=eval_dir)
logger.info("score: {}".format(score_dict))
score_dicts.append(score_dict)
return score_dicts[0]

ssll/final_score.py
@@ -0,0 +1,88 @@
import json
import os
import argparse
# from settings import parse_args
parser = argparse.ArgumentParser()
parser.add_argument('--tasks',nargs='+', default=['ag'])
parser.add_argument('--output_dir',type=str)
parser.add_argument('--res_dir',type=str, default='scores')
parser.add_argument('--exp',type=str)
args = parser.parse_args()
# TODO: Calculate average scores and average forgetting. =================================================================================
test_res_path = os.path.join(args.res_dir, args.exp+'_res.txt')
test_res = open(test_res_path, 'a')
# * Load saved scores.
score_all_time=[]
with open(os.path.join(args.output_dir, "metrics.json"),"r") as f:
score_dict = [json.loads(row) for row in f.readlines()]
# print(f'SCORE DICT {score_dict}',flush=True)
# * Calculate Average Score of all tasks observed.
# score_last_time = score_dict[-1] # last time and last epoch
# average_score = score_last_time['Average']
# print(f'Last time scores for all tasks:\n {score_last_time}', file=test_res)
# print('Last time scores for all tasks:\n',score_last_time, flush=True)
# print(f"Average score is {average_score['score']}; teacher average score is {average_score['tc_score']}", file=test_res)
# print(f"Average score is {average_score['score']}; teacher average score is {average_score['tc_score']}", flush=True)
# * Calculate the max scores of all tasks
tasks_max_scores = {}
for task in args.tasks:
for row in score_dict:
key_name = list(row.keys())[0]
if task+'_finish' in key_name:
if task not in tasks_max_scores:
# print(row[task+'_finish'])
tasks_max_scores[task] = max(row[task+'_finish']['score'],row[task+'_finish']['tc_score'])
if 'prev_tasks_scores' in key_name:
for term in row['prev_tasks_scores']:
if task in term:
new_task_score = max(term[task]['score'], term[task]['tc_score'])
if new_task_score > tasks_max_scores[task]:
tasks_max_scores[task] = new_task_score
print(f'Tasks max scores are {tasks_max_scores}')
print(f'Average score is {sum(tasks_max_scores.values())/len(tasks_max_scores.values())}', file=test_res)
print(f'Average score is {sum(tasks_max_scores.values())/len(tasks_max_scores.values())}')
# * Calculate Backward transfer.
each_finish_score = []
each_finish_tc_score = []
last_task_name = args.tasks[-1]
for row in score_dict:
key_name = list(row.keys())[0]
if '_finish' in key_name and last_task_name not in key_name:
# print(f'ROW: {row}',flush=True)
for k,v in row.items():
# print(v, flush=True)
each_finish_score.append(v['score'])
each_finish_tc_score.append(v['tc_score'])
average_wo_last = score_dict[-2]
# bwt = - min(sum(each_finish_score) / len(each_finish_score), sum(each_finish_tc_score) / len(each_finish_tc_score)) + max(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
# bwt = - min(sum(each_finish_score) / len(each_finish_score), sum(each_finish_tc_score) / len(each_finish_tc_score)) + min(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
# tc_bwt = - sum(each_finish_tc_score) / len(each_finish_tc_score) + max(average_wo_last['Average_wo_curr_task']['score'], average_wo_last['Average_wo_curr_task']['tc_score'])
prev_max_scores = [tasks_max_scores[task] for task in args.tasks[:-1]]
print(f'Each finish score {each_finish_score}', flush=True)
bwt = (- min(sum(each_finish_score), sum(each_finish_tc_score)) + sum(prev_max_scores)) / len(args.tasks[:-1])
print(f"Backward transfer is {bwt}", file=test_res)
print(f"Backward transfer is {bwt}", flush=True)
# # * Calculate Forward transfer
# each_initial_score = []
# # each_initial_tc_score = []
# for row in score_dict:
# if '_initial' in list(row.keys())[0]:
# for k,v in row.items():
# each_initial_score.append(v['score'])
# # each_initial_tc_score.append(v['tc_score'])
# fwt = sum(each_initial_score) / len(each_initial_score)
# # tc_fwt = sum(each_initial_tc_score) / len(each_initial_tc_score)
# # * Print out final results of scores.
# print(f'Forward transfer is {fwt}', file=test_res)
# print(f'Forward transfer is {fwt}', flush=True)
# print(average_score['tc_score'], tc_bwt,file=test_res)

ssll/lamol_dataset.py
@@ -0,0 +1,434 @@
import torch
import csv
import os
import re
import json
import numpy as np
from settings import parse_args
from eda import *
# from train import special_tokens, special_token_ids
args = parse_args()
class LBIDDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
self.max_ans_len = 10
self.special_token_ids = special_token_ids
with open(data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def get_answer(self, target): # get answer text from intent
# replace - and _ to make the text seem more natural
ans = target.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
ans = self.get_answer(example['output'])
num_aug = 0 # Number of augmented sentences per sample.
if not args.meantc or num_aug == 0:
return [(text, ans)]
else:
aug_text_list = eda(text, num_aug=num_aug)
res = [(aug_text, ans) for aug_text in aug_text_list]
res.append((text, ans))
return res
def per_tokenization(self, task_name, d):
# print(d,flush=True)
input_text, ans_text = d
if args.use_task_pmt:
# prefix_id = self.tokz.encode(task_name+':')
input_sen = task_name+':' + input_text
else:
# prefix_id = self.tokz.encode('TASK:')
input_sen = 'TASK:' + input_text
context_sen = input_sen + '<ANS>'
all_sen = input_sen + '<ANS>' + ans_text  # assumed fix: include the task prompt so context_sen is a prefix of all_sen
ans_sen = ans_text
# ans_id = self.tokz.encode(ans_text)
# self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
# return {
# 'all_id': self.tokz.encode(all_sen),
# 'input_id': self.tokz.encode(input_sen),
# 'context_id': self.tokz.encode(context_sen),
# 'ans_id': ans_id
# }
return {
'all_sen': all_sen,
'input_sen': input_sen,
'context_sen': context_sen,
'ans_sen': ans_sen,
}
def data_tokenization(self, task_name, data):
# print('data',data[:10],flush=True)
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
class ULBIDDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
self.max_ans_len = 0
self.special_token_ids = special_token_ids
with open(data_path, "r", encoding='utf-8') as f:
data = [json.loads(i) for i in f.readlines()]
data = [self.parse_example(i) for i in data] # [utter, label]
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def get_answer(self, target): # get answer text from intent
# replace - and _ to make the text seem more natural
ans = target.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
ans = self.get_answer(example['output'])
return text, ans
def per_tokenization(self, task_name, d):
input_text, ans_text = d
if args.use_task_pmt:
# prefix_id = self.tokz.encode(task_name+':')
input_sen = task_name+':' + input_text
else:
# prefix_id = self.tokz.encode('TASK:')
input_sen = 'TASK:' + input_text
context_sen = input_sen + '<ANS>'
all_sen = input_sen + '<ANS>' + ans_text  # assumed fix: include the task prompt so context_sen is a prefix of all_sen
ans_sen = ans_text
# ans_id = self.tokz.encode(ans_text)
# self.max_ans_len = max(self.max_ans_len, len(ans_id) + 1)
# return {
# 'all_id': self.tokz.encode(all_sen),
# 'input_id': self.tokz.encode(input_sen),
# 'context_id': self.tokz.encode(context_sen),
# 'ans_id': ans_id
# }
return {
'all_sen': all_sen,
'input_sen': input_sen,
'context_sen': context_sen,
'ans_sen': ans_sen,
}
def data_tokenization(self, task_name, data):
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
class PinnedBatch:
def __init__(self, data):
self.data = data
def __getitem__(self, k):
return self.data[k]
def __setitem__(self, key, value):
self.data[key] = value
def pin_memory(self):
for k in self.data.keys():
self.data[k] = self.data[k].pin_memory()
return self
def __repr__(self):
return self.data.__repr__()
def __str__(self):
return self.data.__str__()
def items(self):
return self.data.items()
def keys(self):
return self.data.keys()
def pad_seq(seq, pad, max_len, pad_left=False):
if pad_left:
return [pad] * (max_len - len(seq)) + seq
else:
return seq + [pad] * (max_len - len(seq))
class OldPadBatchSeq:
def __init__(self, pad_id=0):
self.pad_id = pad_id
def __call__(self, batch):
input_id = [i['input_id'] for i in batch]
context_id = [i['context_id'] for i in batch]
all_id = [i['all_id'] for i in batch]
ans_id = [i['ans_id'] for i in batch]
input_lens = [len(i) for i in input_id]
context_lens = [len(i) for i in context_id]
all_lens = [len(i) for i in all_id]
ans_lens = [len(i) for i in ans_id]
input_mask = torch.ByteTensor([[1] * input_lens[i] + [0] * (max(input_lens)-input_lens[i]) for i in range(len(input_id))])
all_mask = torch.ByteTensor([[1] * all_lens[i] + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))])
all_label_mask = torch.ByteTensor([[0] * (context_lens[i]) + [1] * (all_lens[i] - context_lens[i]) + [0] * (max(all_lens)-all_lens[i]) for i in range(len(all_id))]) # Only calculate losses on answer tokens.
res = {}
res["input_id"] = torch.tensor([pad_seq(i, self.pad_id, max(input_lens)) for i in input_id], dtype=torch.long)
res["context_id"] = torch.tensor([pad_seq(i, self.pad_id, max(context_lens), pad_left=True) for i in context_id], dtype=torch.long)
res["all_id"] = torch.tensor([pad_seq(i, self.pad_id, max(all_lens)) for i in all_id], dtype=torch.long)
res["ans_id"] = torch.tensor([pad_seq(i, self.pad_id, max(ans_lens)) for i in ans_id], dtype=torch.long)
res["input_lens"] = torch.tensor(input_lens, dtype=torch.long)
res["context_lens"] = torch.tensor(context_lens, dtype=torch.long)
res["all_lens"] = torch.tensor(all_lens, dtype=torch.long)
res["ans_lens"] = torch.tensor(ans_lens, dtype=torch.long)
res['input_mask'] = input_mask
res['all_mask'] = all_mask
res['all_label_mask'] = all_label_mask
return PinnedBatch(res)
class PadBatchSeq:
def __init__(self, pad_id=0, tokz=None):
self.pad_id = pad_id
self.tokz = tokz
def __call__(self, batch):
input_sen = [i['input_sen'] for i in batch]
context_sen = [i['context_sen'] for i in batch]
all_sen = [i['all_sen'] for i in batch]
ans_sen = [i['ans_sen'] for i in batch]
input_encoding = self.tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
all_encoding = self.tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
context_encoding = self.tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
ans_encoding = self.tokz(ans_sen, padding='longest', max_length=args.max_ans_len, truncation=True, return_tensors='pt')
res = {}
res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
return PinnedBatch(res)
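# Unlike OldPadBatchSeq above, this collator keeps raw strings until batch time
# and lets the tokenizer do the padding/truncation (padding='longest', capped at
# args.max_input_len / args.max_ans_len), so no manual pad_seq/mask bookkeeping
# is needed.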
# * Information for different tasks.
if args.data_type == 'intent':
TASK2INFO = {
"banking": {
"dataset_class": LBIDDataset,
"dataset_folder": "banking",
"task_type": "CLS",
},
"clinc": {
"dataset_class": LBIDDataset,
"dataset_folder": "clinc",
"task_type": "CLS",
},
"hwu": {
"dataset_class": LBIDDataset,
"dataset_folder": "hwu",
"task_type": "CLS",
},
"atis": {
"dataset_class": LBIDDataset,
"dataset_folder": "atis",
"task_type": "CLS",
},
"tod": {
"dataset_class": LBIDDataset,
"dataset_folder": "FB_TOD_SF",
"task_type": "CLS",
},
"snips": {
"dataset_class": LBIDDataset,
"dataset_folder": "snips",
"task_type": "CLS",
},
"top_split1": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split1",
"task_type": "CLS",
},
"top_split2": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split2",
"task_type": "CLS",
},
"top_split3": {
"dataset_class": LBIDDataset,
"dataset_folder": "top_split3",
"task_type": "CLS",
}
}
elif args.data_type =='slot':
TASK2INFO = {
"dstc8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "dstc8",
"task_type": "SlotTagging",
},
"restaurant8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "restaurant8",
"task_type": "SlotTagging",
},
"atis": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "atis",
"task_type": "SlotTagging",
},
"mit_movie_eng": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_eng",
"task_type": "SlotTagging",
},
"mit_movie_trivia": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_trivia",
"task_type": "SlotTagging",
},
"mit_restaurant": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_restaurant",
"task_type": "SlotTagging",
},
"snips": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "snips",
"task_type": "SlotTagging",
}
}
elif args.data_type =='tc':
TASK2INFO = {
"ag": {
"dataset_class": LBIDDataset,
"dataset_folder": "ag",
"task_type": "tc",
},
'dbpedia':{
'dataset_class':LBIDDataset,
'dataset_folder':'dbpedia',
'task_type':'tc',
},
'yelp':{
'dataset_class':LBIDDataset,
'dataset_folder':'yelp',
'task_type':'tc',
},
'yahoo':{
'dataset_class':LBIDDataset,
'dataset_folder':'yahoo',
'task_type':'tc',
},
'snips':{
'dataset_class':LBIDDataset,
'dataset_folder':'snips',
'task_type':'tc',
},
}
def get_unlabel_data(path, task):
info = TASK2INFO[task]
data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
return data_path
def get_unlabel_dict(path, task, tokz, args):
info = TASK2INFO[task]
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
if args.data_type == 'intent':
return LBIDDataset(task, tokz, path, args.ctx_max_len, special_token_ids=special_token_ids)
else:
return None
class MixedDataset(LBIDDataset):
def __init__(self, unlabel_data, label_data, tokz, ctx_max_len=100):
self.ctx_max_len = ctx_max_len
self.tokz = tokz
self.max_ans_len = 0
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
self.data = []
self.data += unlabel_data
self.data += label_data
class AugDataset(LBIDDataset):
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, ctx_max_len=100, special_token_ids=None):
self.tokz = tokz
self.max_ans_len = 0
self.special_token_ids = special_token_ids
with open(label_data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
with open(unlabel_data_path, "r", encoding='utf-8') as f:
ori_data += [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def parse_example(self, example):
text = example['userInput']['text']
aug_text_list = eda(text, num_aug=args.num_aug)
ans = self.get_answer(example['intent'])
res = [(aug_text, ans) for aug_text in aug_text_list]
res.append((text, ans))
return res
def get_datasets(path, tasks, tokz, ctx_max_len=100, special_token_ids=None):
res = {}
for task in tasks:
res[task] = {}
info = TASK2INFO[task]
if args.data_type == "intent" or args.data_type == "tc":
res[task]['label_train'] = LBIDDataset(
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
if args.use_unlabel:
res[task]['unlabel_train'] = ULBIDDataset(
task, tokz, os.path.join(path, info['dataset_folder'], 'unlabel_train.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
if args.meantc:
filename = 'unlabel_' + str(args.num_unlabel) + '_train.json'
res[task]['aug_train'] = AugDataset(task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'),
os.path.join(path, info['dataset_folder'], filename), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), ctx_max_len=ctx_max_len, special_token_ids=special_token_ids)
return res

ssll/metrics.py
@@ -0,0 +1,374 @@
import collections
import string
import re
import numpy as np
AGG_OPS = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
COND_OPS = ['=', '>', '<']  # note: an earlier variant of this list also included 'OP'
def lcsstr(string1, string2):
answer = 0
len1, len2 = len(string1), len(string2)
for i in range(len1):
match = 0
for j in range(len2):
if (i + j < len1 and string1[i + j] == string2[j]):
match += 1
if match > answer:
answer = match
else:
match = 0
return answer
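# lcs() below blends three signals: the normalized longest-common-subsequence
# length is the main term, the longest-common-substring score (scaled by 1e-4)
# breaks ties, and a tiny length penalty (scaled by 1e-8) prefers shorter
# candidates; the scale factors keep the secondary terms from dominating.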
def lcs(X, Y):
m = len(X)
n = len(Y)
L = [[None]*(n + 1) for i in range(m + 1)]
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0 :
L[i][j] = 0
elif X[i-1] == Y[j-1]:
L[i][j] = L[i-1][j-1]+1
else:
L[i][j] = max(L[i-1][j], L[i][j-1])
return L[m][n] / max(m, n) + lcsstr(X, Y) / 1e4 - min(m, n) / 1e8
def normalize_text(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def to_delta_state(line):
delta_state = {'inform': {}, 'request': {}}
try:
if line.lower() == 'none' or line.strip() == '' or line.strip() == ';':
return delta_state
inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')]
inform_pairs = {}
for i in inform:
try:
k, v = i.split(':')
inform_pairs[k.strip()] = v.strip()
except:
pass
delta_state = {'inform': inform_pairs, 'request': request}
except:
pass
finally:
return delta_state
def update_state(state, delta):
for act, slot in delta.items():
state[act] = slot
return state
def dict_cmp(d1, d2):
def cmp(a, b):
for k1, v1 in a.items():
if k1 not in b:
return False
else:
if v1 != b[k1]:
return False
return True
return cmp(d1, d2) and cmp(d2, d1)
def to_lf(s, table):
aggs = [y.lower() for y in AGG_OPS]
agg_to_idx = {x: i for i, x in enumerate(aggs)}
conditionals = [y.lower() for y in COND_OPS]
headers_unsorted = [(y.lower(), i) for i, y in enumerate(table['header'])]
headers = [(y.lower(), i) for i, y in enumerate(table['header'])]
headers.sort(reverse=True, key=lambda x: len(x[0]))
condition_s, conds = None, []
if 'where' in s:
s, condition_s = s.split('where', 1)
s = ' '.join(s.split()[1:-2])
s_no_agg = ' '.join(s.split()[1:])
sel, agg = None, 0
lcss, idxs = [], []
for col, idx in headers:
lcss.append(lcs(col, s))
lcss.append(lcs(col, s_no_agg))
idxs.append(idx)
lcss = np.array(lcss)
max_id = np.argmax(lcss)
sel = idxs[max_id // 2]
if max_id % 2 == 1: # with agg
agg = agg_to_idx[s.split()[0]]
full_conditions = []
if not condition_s is None:
pattern = '|'.join(COND_OPS)
split_conds_raw = re.split(pattern, condition_s)
split_conds_raw = [conds.strip() for conds in split_conds_raw]
split_conds = [split_conds_raw[0]]
for i in range(1, len(split_conds_raw)-1):
split_conds.extend(re.split('and', split_conds_raw[i]))
split_conds += [split_conds_raw[-1]]
for i in range(0, len(split_conds), 2):
cur_s = split_conds[i]
lcss = []
for col in headers:
lcss.append(lcs(col[0], cur_s))
max_id = np.argmax(np.array(lcss))
split_conds[i] = headers[max_id][0]
for i, m in enumerate(re.finditer(pattern, condition_s)):
split_conds[2*i] = split_conds[2*i] + ' ' + m.group()
split_conds = [' '.join(split_conds[2*i:2*i+2]) for i in range(len(split_conds)//2)]
condition_s = ' and '.join(split_conds)
condition_s = ' ' + condition_s + ' '
for idx, col in enumerate(headers):
condition_s = condition_s.replace(' ' + col[0] + ' ', ' Col{} '.format(col[1]))
condition_s = condition_s.strip()
for idx, col in enumerate(conditionals):
new_s = []
for t in condition_s.split():
if t == col:
new_s.append('Cond{}'.format(idx))
else:
new_s.append(t)
condition_s = ' '.join(new_s)
s = condition_s
conds = re.split(r'(Col\d+ Cond\d+)', s)  # raw string avoids the invalid-escape warning
if len(conds) == 0:
conds = [s]
conds = [x for x in conds if len(x.strip()) > 0]
full_conditions = []
for i, x in enumerate(conds):
if i % 2 == 0:
x = x.split()
col_num = int(x[0].replace('Col', ''))
opp_num = int(x[1].replace('Cond', ''))
full_conditions.append([col_num, opp_num])
else:
x = x.split()
if x[-1] == 'and':
x = x[:-1]
x = ' '.join(x)
if 'Col' in x:
new_x = []
for t in x.split():
if 'Col' in t:
idx = int(t.replace('Col', ''))
t = headers_unsorted[idx][0]
new_x.append(t)
x = new_x
x = ' '.join(x)
if 'Cond' in x:
new_x = []
for t in x.split():
if 'Cond' in t:
idx = int(t.replace('Cond', ''))
t = conditionals[idx]
new_x.append(t)
x = new_x
x = ' '.join(x)
full_conditions[-1].append(x.replace(' ', ''))
logical_form = {'sel': sel, 'conds': full_conditions, 'agg': agg}
return logical_form
def computeLFEM(greedy, answer):
count = 0
correct = 0
text_answers = []
for idx, (g, ex) in enumerate(zip(greedy, answer)):
count += 1
text_answers.append([ex['answer'].lower()])
try:
gt = ex['sql']
conds = gt['conds']
lower_conds = []
for c in conds:
lc = c
lc[2] = str(lc[2]).lower().replace(' ', '')
lower_conds.append(lc)
gt['conds'] = lower_conds
lf = to_lf(g, ex['table'])
correct += lf == gt
except Exception as e:
continue
return (correct / count) * 100, text_answers
def f1_score(prediction, ground_truth):
prediction_tokens = prediction.split()
ground_truth_tokens = ground_truth.split()
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match(prediction, ground_truth):
# print(f'pred {prediction}, gt {ground_truth}')
return prediction == ground_truth
def computeF1(outputs, targets):
return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)]) / len(outputs) * 100
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for idx, ground_truth in enumerate(ground_truths):
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def computeEM(outputs, targets):
outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)]
# print(f'EM outs: {outs}',flush=True)
return sum(outs) / len(outputs) * 100
def score(answer, gold):
if len(gold) > 0:
gold = set.union(*[simplify(g) for g in gold])
answer = simplify(answer)
tp, tn, sys_pos, real_pos = 0, 0, 0, 0
if answer == gold:
if not ('unanswerable' in gold and len(gold) == 1):
tp += 1
else:
tn += 1
if not ('unanswerable' in answer and len(answer) == 1):
sys_pos += 1
if not ('unanswerable' in gold and len(gold) == 1):
real_pos += 1
return np.array([tp, tn, sys_pos, real_pos])
def simplify(answer):
return set(''.join(c for c in t if c not in string.punctuation) for t in answer.strip().lower().split()) - {'the', 'a', 'an', 'and', ''}
def computeCF1(greedy, answer):
scores = np.zeros(4)
for g, a in zip(greedy, answer):
scores += score(g, a)
tp, tn, sys_pos, real_pos = scores.tolist()
total = len(answer)
if tp == 0:
p = r = f = 0.0
else:
p = tp / float(sys_pos)
r = tp / float(real_pos)
f = 2 * p * r / (p + r)
return f * 100, p * 100, r * 100
def computeDialogue(greedy, answer):
examples = []
for idx, (g, a) in enumerate(zip(greedy, answer)):
examples.append((a[0], g, a[1], idx))
#examples.sort()
turn_request_positives = 0
turn_goal_positives = 0
joint_goal_positives = 0
ldt = None
for ex in examples:
if ldt is None or ldt.split('_')[:-1] != ex[0].split('_')[:-1]:
state, answer_state = {}, {}
ldt = ex[0]
delta_state = to_delta_state(ex[1])
answer_delta_state = to_delta_state(ex[2])
state = update_state(state, delta_state['inform'])
answer_state = update_state(answer_state, answer_delta_state['inform'])
if dict_cmp(state, answer_state):
joint_goal_positives += 1
if delta_state['request'] == answer_delta_state['request']:
turn_request_positives += 1
if dict_cmp(delta_state['inform'], answer_delta_state['inform']):
turn_goal_positives += 1
joint_goal_em = joint_goal_positives / len(examples) * 100
turn_request_em = turn_request_positives / len(examples) * 100
turn_goal_em = turn_goal_positives / len(examples) * 100
answer = [(x[-1], x[-2]) for x in examples]
#answer.sort()
answer = [[x[1]] for x in answer]
return joint_goal_em, turn_request_em, turn_goal_em, answer
def compute_metrics(data, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False):
greedy = [datum[0] for datum in data] # prediction
answer = [datum[1] for datum in data] # ground truth
for i in range(min(3, len(data))):  # debug: peek at a few (pred, gold) pairs
pred, tgt = data[i]
print('pred', pred, 'tgt', tgt, flush=True)
metric_keys = []
metric_values = []
if logical_form:
lfem, answer = computeLFEM(greedy, answer)
metric_keys += ['lfem']
metric_values += [lfem]
em = computeEM(greedy, answer)
metric_keys.append('em')
metric_values.append(em)
norm_greedy = [normalize_text(g) for g in greedy]
norm_answer = [[normalize_text(a) for a in ans] for ans in answer]
nf1 = computeF1(norm_greedy, norm_answer)
nem = computeEM(norm_greedy, norm_answer)
metric_keys.extend(['nf1', 'nem'])
metric_values.extend([nf1, nem])
if corpus_f1:
corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
metric_keys += ['corpus_f1', 'precision', 'recall']
metric_values += [corpus_f1, precision, recall]
if dialogue:
joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
avg_dialogue = (joint_goal_em + request_em) / 2
metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
metric_dict = collections.OrderedDict(list(zip(metric_keys, metric_values)))
return metric_dict
task2metric_dict = {'woz.en':'avg_dialogue',
'wikisql':'lfem',
'squad':'nf1',
'sst':'em',
'srl':'nf1',
'ag':'em',
'yahoo':'em',
'yelp':'em',
'amazon':'em',
'dbpedia':'em'
}
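# Putting the pieces together (sketch; compute_metrics expects a list of
# (prediction, [gold answers]) text pairs):
#   scores = compute_metrics([('positive', ['positive']), ('negative', ['positive'])])
#   print(scores['em'], scores['nf1'])  # -> 50.0 50.0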

@@ -0,0 +1,18 @@
import json
import os, sys
oridir="DATA_Divided/DATA_100/ag"
outpath="DATA_Divided/DATA_100/ag/pretrain.txt"
datatype_list=['label_train','unlabel_train']
with open(outpath,'w') as fw:
for datatype in datatype_list:
datapath = os.path.join(oridir,datatype+'.json')
with open(datapath,'r') as f:
data = [json.loads(i) for i in f.readlines()]
for row in data:
# print(json.dumps(row['input'].strip('"'), ensure_ascii=False), file=fw)
print(row['input'].strip('"'), file=fw)  # strip the quotes from the value, not the key

ssll/pretrain.py
File diff suppressed because it is too large.

ssll/random_order.py
@@ -0,0 +1,25 @@
import random
# seed_list = [2, 102, 22, 32, 42]
seed_list = [2, 32, 42]
# intent_list = ['atis', 'banking', 'clinc', 'hwu', 'snips', 'top_split1', 'top_split2' ,'top_split3']
# res_order=open('six_order_intent.txt','w')
# prelist = ['atis', 'dstc8', 'mit_movie_eng', 'snips', 'mit_movie_trivia', 'mit_restaurant']
# prelist = ['yahoo','yelp','ag','dbpedia','amazon']
# prelist = ['srl','sst','wikisql','woz.en','squad']
prelist = ['srl','sst','wikisql','woz.en','squad','yahoo','yelp','ag','dbpedia','amazon']
# res_order=open('text_classification_order.txt','w')
res_order=open('mix_order.txt','w')
res_list = []
for i in range(len(seed_list)):
random.seed(seed_list[i])
random.shuffle(prelist)
# print(intent_list)
# res_list.append(intent_list)
# print(intent_list,file=res_order)
print('('+' '.join(prelist)+')',file=res_order)
# print('\n',file=res_order)
res_order.close()

ssll/requirements.txt
@@ -0,0 +1,14 @@
adapter_transformers==3.0.0
datasets==2.7.1
flax==0.6.2
GPUtil==1.4.0
huggingface_hub==0.11.0
jax==0.3.25
nltk==3.6.7
numpy==1.19.5
optax==0.1.4
rouge==1.0.1
scikit_learn==1.1.3
torch==1.8.2
tqdm==4.64.0
transformers==4.2.2

ssll/scripts/lltrain.sh
@@ -0,0 +1,262 @@
gpuid=$1
# K=$2
K=100
R=20 #* Ratio of unlabeled data over labeled data.
# K=5000
run=K${K}_U${R}_run
ord=0
# data_type=tc #!
# data_type=decanlp
data_type=mix
exp=semi_${data_type}_$run
task_type=semi #!
use_unlabel=true
meantc=true #* Use mean teacher for semi-supervised learning.
echo exp: $exp
echo data_type: $data_type
echo K: $K U: $R
echo order: $ord
# * tc ===========================================
if [[ $data_type == tc ]]
then
if [[ $ord == 0 ]]
then
task_list=(ag yelp dbpedia amazon yahoo) #ord0
elif [[ $ord == 1 ]]
then
task_list=(yahoo amazon ag dbpedia yelp) #ord1
elif [[ $ord == 2 ]]
then
task_list=(ag dbpedia yahoo yelp amazon) #ord2
elif [[ $ord == 3 ]]
then
task_list=(yahoo yelp amazon dbpedia ag) #ord3
elif [[ $ord == 4 ]]
then
task_list=(dbpedia yelp amazon ag yahoo) #ord4
fi
max_input_len=150
max_ans_len=20
epochs=120
# epochs=75
train_bs=16
gradient_accumulation_steps=1
eval_bs=16
feedback_threshold=0.05
simi_tau=0.8
# simi_tau=0.65
# add_confidence_selection=
add_confidence_selection=true
fi
# * decanlp ======================================
if [[ $data_type == decanlp ]]
then
if [[ $ord == 0 ]]
then
task_list=(wikisql sst woz.en squad srl) #ord0
# task_list=(woz.en sst)
elif [[ $ord == 1 ]]
then
task_list=(srl squad wikisql woz.en sst) #ord1
elif [[ $ord == 2 ]]
then
task_list=(wikisql woz.en srl sst squad) #ord2
elif [[ $ord == 3 ]]
then
task_list=(srl sst squad woz.en wikisql) #ord3
elif [[ $ord == 4 ]]
then
task_list=(woz.en sst squad wikisql srl) #ord4
fi
max_input_len=512
max_ans_len=100
epochs=200
# epochs=100 #!test
train_bs=4
gradient_accumulation_steps=4
eval_bs=16
feedback_threshold=0.1
# feedback_threshold=0.000001 #!test
simi_tau=0.6
add_confidence_selection=true
fi
# * MIX ================================================
if [[ $data_type == mix ]]
then
if [[ $ord == 0 ]]
then
task_list=(yahoo amazon woz.en squad yelp ag wikisql dbpedia sst srl) #ord0
elif [[ $ord == 1 ]]
then
task_list=(yelp wikisql yahoo sst srl ag dbpedia woz.en squad amazon) #ord1
elif [[ $ord == 2 ]]
then
task_list=(woz.en sst yahoo squad ag dbpedia amazon srl yelp wikisql) #ord2
fi
max_input_len=512
max_ans_len=100
epochs=200
train_bs=4
gradient_accumulation_steps=4
eval_bs=16
feedback_threshold=0.1
simi_tau=0.6
add_confidence_selection=true
fi
# *----------------------------------------------
echo task list: ${task_list[@]}
# random_initialization=true
random_initialization=
num_label=$K
unlabel_ratio=$R
use_task_pmt=true
test_all=true
# test_all= #!test
# debug_use_unlabel=true
debug_use_unlabel=
# evaluate_zero_shot=true
evaluate_zero_shot=
lr=2e-4
# warmup_epoch=1
warmup_epoch=2 # * warm up epoch
freeze_plm=true
# freeze_plm=
evalstep=500000
# test_overfit=true #! Use label_train for evaluation
test_overfit=
# * Unlabel hyper-params.
gen_replay=true #* Perform generative replay
# gen_replay=
pseudo_data_ratio=0.1
# pseudo_data_ratio=0.01 #!test
construct_memory=true #* Construct memory (Forward Aug)
# construct_memory=
backward_augment=true #* Backward Aug
# backward_augment=
back_kneighbors=3 # for backward_augment to retrieve the current unlabel memory
# forward_augment=true #* Forward Aug
forward_augment=
kneighbors=3 #
# kneighbors=5 # for forward_augment to retrieve the old memory
select_adapter=
# select_adapter=true
unlabel_amount='1'
# unlabel_amount=$unlabel_ratio
# unlabel_amount='5'
# ungamma=0.1
ungamma=0.01
add_unlabel_lm_loss=true
# add_unlabel_lm_loss=
add_label_lm_loss=true
# add_label_lm_loss=
accum_grad_iter=$unlabel_amount
# feedback_threshold=0.5
# lm_lambda=0.25
lm_lambda=0.5
KD_temperature=2
KD_term=1
diff_question=true
pseudo_tau=1.5
num_aug=3 # for EDA input augmentation
consistency=10
consistency_rampup=30
# consistency_rampup=10
# ema_decay=0.90
ema_decay=0.95
stu_feedback=true
input_aug=true
rdrop=true
model_aug=true
# model_aug=
# rdrop=
# *--------------------------------------------------
# datadir=./data_lamol/TC
if [[ $data_type == tc ]]
then
datadir=../../DATA/MYDATA_DIVIDED/TC/label_$K
elif [[ $data_type == decanlp ]]
then
datadir=../../DATA/MYDATA_DIVIDED/decaNLP/label_$K
elif [[ $data_type == mix ]]
then
datadir=../../DATA/MYDATA_DIVIDED/MIX/label_$K
fi
output=outputs/${data_type}/$exp/ord$ord
tb_log_dir=tb_logs/${data_type}/$exp/ord$ord
log=logs/${data_type}/${exp}_ord${ord}.log
err=logs/${data_type}/${exp}_ord${ord}.err
mkdir -p logs/${data_type} outputs $output $tb_log_dir
# *--------------------------------------------------
# python_file=tc_train.py
python_file=unitrain.py
# TODO: Running ================================================
CUDA_VISIBLE_DEVICES=$gpuid \
python $python_file \
--gpu=$gpuid \
--data_dir=$datadir \
--output_dir=$output \
--tb_log_dir=$tb_log_dir \
--tasks ${task_list[*]} \
--experiment=$exp \
--num_train_epochs=$epochs \
--train_batch_size=$train_bs \
--eval_batch_size=$eval_bs \
--eval_steps=$evalstep \
--data_type=$data_type \
--use_unlabel=$use_unlabel \
--pseudo_tau=$pseudo_tau \
--warmup_epoch=$warmup_epoch \
--ungamma=$ungamma \
--unlabel_amount=$unlabel_amount \
--meantc=$meantc \
--consistency=$consistency \
--num_aug=$num_aug \
--consistency_rampup=$consistency_rampup \
--test_overfit=$test_overfit \
--ema_decay=$ema_decay \
--num_label=$num_label \
--use_task_pmt=$use_task_pmt \
--freeze_plm=$freeze_plm \
--input_aug=$input_aug \
--model_aug=$model_aug \
--stu_feedback=$stu_feedback \
--rdrop=$rdrop \
--test_all=$test_all \
--unlabel_ratio=$unlabel_ratio \
--gen_replay=$gen_replay \
--add_unlabel_lm_loss=$add_unlabel_lm_loss \
--add_label_lm_loss=$add_label_lm_loss \
--lr=$lr \
--accum_grad_iter=$accum_grad_iter \
--feedback_threshold=$feedback_threshold \
--lm_lambda=$lm_lambda \
--kneighbors=$kneighbors \
--construct_memory=$construct_memory \
--KD_term=$KD_term \
--diff_question=$diff_question \
--KD_temperature=$KD_temperature \
--add_confidence_selection=$add_confidence_selection \
--debug_use_unlabel=$debug_use_unlabel \
--gradient_accumulation_steps=$gradient_accumulation_steps \
--max_input_len=$max_input_len \
--max_ans_len=$max_ans_len \
--backward_augment=$backward_augment \
--back_kneighbors=$back_kneighbors \
--similarity_tau=$simi_tau \
--select_adapter=$select_adapter \
--pseudo_data_ratio=$pseudo_data_ratio \
--forward_augment=$forward_augment \
--evaluate_zero_shot=$evaluate_zero_shot \
--random_initialization=$random_initialization \
> $log 2> $err
#! not log test
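# Usage sketch (assumption based on gpuid=$1 above): e.g.
#   sh scripts/lltrain.sh 0
# runs the configured mixed-task experiment on GPU 0, with stdout/stderr
# redirected to $log and $err under logs/${data_type}/.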

163
ssll/settings.py Normal file
View File

@ -0,0 +1,163 @@
import os
import argparse
import torch
def parse_args():
parser = argparse.ArgumentParser()
# * New arguments for semi-supervised continual learning.
parser.add_argument('--newmm_size', default=0.2, type=float, help='Different memory size for storing new task unlabeled data.')
parser.add_argument('--random_initialization',default=False, type=bool)
parser.add_argument('--evaluate_zero_shot',default=False, type=bool)
parser.add_argument('--forward_augment',default=False, type=bool, help='Whether apply forward augmentation from old constructed memory.')
parser.add_argument('--backward_augment', type=bool, default=False, help='Whether use the current task unlabeled data to augment the old tasks.')
parser.add_argument('--select_adapter',default=False, type=bool, help='Whether select the previous adapter for the new task initialization.')
parser.add_argument('--similarity_tau',default=0.6, type=float, help='Similarity threshold of cosine for KNN.')
parser.add_argument('--back_kneighbors',default=5, type=int, help='Retrieve K nearest neighbors from the current unlabeled memory.')
parser.add_argument('--diff_question', type=bool, default=False, help='If true, the dataset has different questions for each sample.')
parser.add_argument('--debug_use_unlabel', default=False, type=bool, help='Always use unlabeled data for debugging.')
parser.add_argument('--add_confidence_selection', type=bool, default=False, help='Whether add confidence selection for pseudo label.')
parser.add_argument('--kneighbors', default=5, type=int, help='Retrieve K nearest neighbors from the memory.')
parser.add_argument('--construct_memory',default=False, help='Construct memory for generated pseudo data from old tasks.')
parser.add_argument('--aug_neighbors_from_memory', default=False, help='Whether augment retrieved neighbors inputs from memory.')
parser.add_argument('--lm_lambda',type=float, default=0.5, help='The coefficient of language modeling loss.')
parser.add_argument('--accum_grad_iter', type=int, default=1, help='Accumulation steps for gradients.')
parser.add_argument('--add_label_lm_loss', type=bool, default=False, help='Whether compute lm loss on the labeled inputs.')
parser.add_argument('--add_unlabel_lm_loss', type=bool, default=False, help='Whether compute lm loss on the unlabeled inputs.')
parser.add_argument('--pretrain_epoch',type=int, default=20, help='Pretrain first for e.g. 20 epochs before finetuning.')
parser.add_argument('--pretrain_first', type=bool, default=False, help='Whether to perform pretraining before each task finetuning.')
parser.add_argument('--test_all', type=bool, default=False, help='Whether evaluate all test data after training the last epoch.')
parser.add_argument('--add_pretrain_with_ft', type=bool, default=False, help='Whether add pretraining MLM loss during finetuning.')
parser.add_argument('--rdrop', default=False, type=bool, help='Whether use R-drop as the consistency regularization, where dropout is used as model augmentation.')
parser.add_argument('--stu_feedback', default=False, type=bool, help='Whether to consider feedback from student to teacher on labeled datasets.')
parser.add_argument('--feedback_threshold', default=1e-1, type=float, help='Threshold of student feedback to determine whether to use teacher to teach the student.')
parser.add_argument('--dropout_rate',default=0.3, type=float, help='Modify dropout rate for model augmentation with dropout.')
parser.add_argument('--freeze_plm', default=False, type=bool, help='Whether to freeze the pretrained LM and only finetuning the added adapters.')
parser.add_argument('--max_input_len',default=512, type=int, help='Max number of input tokens.')
parser.add_argument('--max_ans_len', default=100, type=int, help='Max number of answer tokens.')
parser.add_argument('--use_task_pmt',default=False, type=bool, help='Whether use task specific token as the prefix.')
parser.add_argument('--use_ema', default=False, type=bool, help='Use EMA for fixmatch.')
parser.add_argument('--fixmatch', default=False, type=bool, help='Use fixmatch to solve semi-supervised learning.')
parser.add_argument('--num_label', default=500, type=int, help='Number of labeled data we use.')
parser.add_argument('--unlabel_ratio', default=500, type=int, help='Ratio of unlabeled data over labeled data we use.')
parser.add_argument('--logit_distance_cost', default=-1, type=float, help='let the student model have two outputs and use an MSE loss between the logits with the given weight (default: only have one output)')
parser.add_argument('--test_overfit', default=False, type=bool, help='Test whether the model can overfit on labeled train dataset.')
parser.add_argument('--num_aug', type=int, default=3, help='Number of augmented sentences per sample with EDA.')
parser.add_argument('--ema_decay', default=0.999, type=float, metavar='ALPHA', help='ema variable decay rate (default: 0.999)')
parser.add_argument('--consistency_rampup', type=int, default=30, metavar='EPOCHS', help='length of the consistency loss ramp-up')
parser.add_argument('--consistency', type=float, default=100.0, metavar='WEIGHT', help='use consistency loss with given weight (default: None)')
parser.add_argument('--meantc', type=bool, default=False, help='Whether to use mean teacher to deal with SSL.')
parser.add_argument('--only_update_prompts', type=bool, default=False, help='Only update the prompt tokens params.')
parser.add_argument('--pl_sample_num', type=int, default=1, help='Number of sampling for pseudo labeling.')
parser.add_argument('--pl_sampling', type=bool, default=False, help='Perform pseudo labeling with sampling for unlabeled data. That means, sample several times for each example, and choose the one with the lowest ppl.')
parser.add_argument('--unlabel_amount', type=str, default='1', help='How much unlabeled data used per labeled batch.')
parser.add_argument('--ungamma', type=float, default=0.2, help='Weight added to the unlabeled loss.')
parser.add_argument('--KD_term', type=float, default=1, help="Control how many teacher model signal is passed to student.")
parser.add_argument('--KD_temperature', type=float, default=2.0, help="Temperature used to calculate KD loss.")
parser.add_argument('--consist_regularize', type=bool, default=False, help='Whether to use consistency regularization for SSL.')
parser.add_argument('--stu_epochs', type=int, default=1, help='Number of epochs for training student model.')
parser.add_argument('--noisy_stu', type=bool, default=False, help='Whether to use noisy student self-training framework for SSL, generate pseudo labels for all unlabeled data once.')
parser.add_argument('--online_noisy_stu', type=bool, default=False, help='Noisy student self-training framework with PL and CR. Perform pseudo labeling online.')
parser.add_argument('--naive_pseudo_labeling', type=bool, default=False, help='Whether to use pseudo-labeling for semi-supervised learning.')
parser.add_argument('--only_pseudo_selection', type=bool, default=False, help='Whether to select pseudo labels with high confidence.')
parser.add_argument('--pretrain_ul_input', type=bool, default=False, help='Only pretrain the inputs of unlabeled data.')
parser.add_argument('--use_unlabel', type=bool, default=False, help='Whether to use unlabeled data during training.')
parser.add_argument('--use_infix', type=bool, default=False, help='Whether to use infix prompts.')
parser.add_argument('--use_prefix', type=bool, default=False, help='Whether to use prefix prompts.')
parser.add_argument('--preseqlen', type=int, default=5, help='Number of prepended prefix tokens.')
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
parser.add_argument('--pseudo_tau', type=float, default=0.6, help='Threshold to choose pseudo labels with high confidence.')
parser.add_argument('--input_aug', type=bool, default=False, help='Whether to use EDA on inputs.')
parser.add_argument('--model_aug', type=bool, default=False, help='Whether to use dropout on the model as the data augmentation.')
parser.add_argument('--warmup_epoch', type=int, default=10, help='Not use unlabeled data before training certain epochs.')
# * Old arguments.
parser.add_argument('--data_type',type=str,default='intent')
parser.add_argument('--use_memory', default=False, type=bool, help="Whether store the learned latent variables z and other info into the memory.")
parser.add_argument('--experiment', type=str)
parser.add_argument('--tasks',nargs='+', default=['banking'])
parser.add_argument("--data_dir", default="./PLL_DATA/", type=str, help="The path to train/dev/test data files.") # Default parameters are set based on single GPU training
parser.add_argument("--output_dir", default="./output/dstc", type=str,help="The output directory where the model checkpoints and predictions will be written.")
parser.add_argument("--tb_log_dir", default="./tb_logs/dstc", type=str,help="The tensorboard output directory.")
parser.add_argument("--res_dir",default="./res", type=str,help="The path to save scores of experiments.")
parser.add_argument("--gene_batch_size", default=16, type=int) # For generation.
parser.add_argument('--top_k', type=int, default=100)
parser.add_argument('--top_p', type=float, default=0.95)
parser.add_argument('--do_train',type=bool, default=True)
parser.add_argument('--gen_replay',type=bool, default=False, help='Whether use generative replay to avoid forgetting.')
parser.add_argument('--model_path', type=str, help='pretrained model path to local checkpoint')
parser.add_argument('--generate_dur_train', type=bool, help='Generate reconstructed input utterances during training with CVAE.')
parser.add_argument('--temperature',type=float, default=0.95)
parser.add_argument('--pseudo_data_ratio', type=float, default=0.05, help="How many pseudo data to generate for each learned task")
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--model_type', type=str, default='cvae', choices=['cvae', 'ae_vae_fusion'])
parser.add_argument('--warmup', type=int, default=100, help="Amount of iterations to warmup, then decay. (-1 for no warmup and decay)")
parser.add_argument("--train_batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
parser.add_argument("--eval_batch_size", default=16, type=int, help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--switch-time', type=float, default=0, help="Percentage of iterations to spend on short sequence training.")
parser.add_argument('--load', type=str, help='path to load model from') # , default='out/test/'
parser.add_argument('--workers', default=1, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--no_gpu', action="store_true")
# * KL cost annealing, increase beta from beta_0 to 1 in beta_warmup steps
parser.add_argument('--beta_0', default=1.00, type=float)
parser.add_argument('--beta_warmup', type=int, default=50000)
parser.add_argument('--cycle', type=int, default=1000)
parser.add_argument("--filename",type=str,help="Data original file to be preprocessed.")
parser.add_argument("--init_model_name_or_path", default="./dir_model/gpt2", type=str, help="Path to init pre-trained model")
parser.add_argument("--num_workers", default=1, type=int, help="workers used to process data")
parser.add_argument("--local_rank", help='used for distributed training', type=int, default=-1)
# Other parameters
parser.add_argument("--ctx_max_len", default=128, type=int, help="Maximum input length for the sequence")
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--weight_decay", default=0.00, type=float, help="Weight deay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=10, type=int, help="Total number of training epochs for each task.")
parser.add_argument("--warmup_steps", default=-1, type=int, help="Linear warmup step. Will overwrite warmup_proportion.")
parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Linear warmup over warmup_proportion * steps.")
parser.add_argument("--nouse_scheduler", action='store_true', help="dont use get_linear_schedule_with_warmup, use unchanged lr")
parser.add_argument('--logging_steps', type=int, default=10, help="Log every X updates steps.")
parser.add_argument('--save_epochs', type=float, default=2, help="Save checkpoint every X epochs.")
parser.add_argument('--eval_steps', type=int, default=20, help="Eval current model every X steps.")
parser.add_argument('--eval_times_per_task', type=float, default=10, help="How many times to eval in each task, will overwrite eval_steps.")
parser.add_argument('--seed', type=int, default=42, help="Seed for everything")
parser.add_argument("--log_eval_res", default=True, help="Whether to log out results in evaluation process")
parser.add_argument("--debug", action="store_true", help="Use debug mode")
#
args = parser.parse_args()
args.log_file = os.path.join(args.output_dir, 'log.txt')
if args.local_rank in [0, -1]:
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.tb_log_dir, exist_ok=True)
else:
while (not os.path.isdir(args.output_dir)) or (not os.path.isdir(args.tb_log_dir)):
pass
if args.debug:
args.logging_steps = 1
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
# setup distributed training
distributed = (args.local_rank != -1)
if distributed:
print(args.local_rank)
# torch.cuda.set_device(0)
torch.cuda.set_device(args.local_rank)
args.device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(
backend='nccl', init_method='env://')
else:
if torch.cuda.is_available():
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
args.device = torch.device("cpu")
args.num_train_epochs = {task: args.num_train_epochs for task in args.tasks}
return args
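A note on the `type=bool` flags above: argparse simply calls `bool()` on the raw argument string, so any non-empty value (even `false`) parses as True, and only an empty value parses as False. `scripts/lltrain.sh` relies on exactly this convention by setting each switch either to `true` or to the empty string. A minimal self-contained demonstration (not part of the repo):

```
import argparse

# Mirrors the type=bool pattern used throughout settings.py.
p = argparse.ArgumentParser()
p.add_argument('--freeze_plm', type=bool, default=False)

assert p.parse_args(['--freeze_plm=true']).freeze_plm is True
assert p.parse_args(['--freeze_plm=']).freeze_plm is False      # empty string -> False
assert p.parse_args(['--freeze_plm=false']).freeze_plm is True  # caveat: any non-empty string -> True
```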

15
ssll/single_score.py Normal file
View File

@ -0,0 +1,15 @@
import json
import os
from settings import parse_args
args = parse_args()
score_all_time = []
with open(os.path.join(args.output_dir, "metrics.json"),"r") as f:
score_all_time = [json.loads(row) for row in f.readlines()]
all_scores = []
for row in score_all_time:
value = list(row.values())[0]['intent_acc']
all_scores.append(value)
print('%.5f'%(max(all_scores)))

View File

@ -0,0 +1,5 @@
(ag yelp dbpedia amazon yahoo)
(yahoo amazon ag dbpedia yelp)
(ag dbpedia yahoo yelp amazon)
(yahoo yelp amazon dbpedia ag)
(dbpedia yelp amazon ag yahoo)

60
ssll/tools/backtrans.py Normal file
View File

@ -0,0 +1,60 @@
from transformers import MarianMTModel, MarianTokenizer
import torch
# target_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
# tgt_tokz_path='out_BT/tgt_tokz'
# tgt_model_path='out_BT/tgt_model'
# target_tokenizer = MarianTokenizer.from_pretrained(tgt_tokz_path)
# target_model = MarianMTModel.from_pretrained(tgt_model_path)
# # en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
# en_tokz_path='out_BT/en_tokz'
# en_model_path='out_BT/en_model'
# en_tokenizer = MarianTokenizer.from_pretrained(en_tokz_path)
# en_model = MarianMTModel.from_pretrained(en_model_path)
# target_tokenizer.save_pretrained('out_BT/tgt_tokz.pt')
# target_model.save_pretrained('out_BT/tgt_model.pt')
# en_tokenizer.save_pretrained('out_BT/en_tokz.pt')
# en_model.save_pretrained('out_BT/en_model.pt')
# target_model, en_model = target_model.to('cuda'), en_model.to('cuda')
def translate(texts, model, tokenizer, language="fr"):
# Prepare the text data into appropriate format for the model
def template(
text): return f"{text}" if language == "en" else f">>{language}<< {text}"
src_texts = [template(text) for text in texts]
# Tokenize the texts
encoded = tokenizer.prepare_seq2seq_batch(src_texts)
# input_ids = tokenizer.encode(src_texts)
for k, v in encoded.items():
encoded[k] = torch.tensor(v)
encoded = encoded.to('cuda')
# Generate translation using model
translated = model.generate(**encoded)
# Convert the generated tokens indices back into text
translated_texts = tokenizer.batch_decode(
translated, skip_special_tokens=True)
return translated_texts
def back_translate(texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr"):
# Translate from source to target language
fr_texts = translate(texts, target_model, target_tokenizer,
language=target_lang)
# Translate from target language back to source language
back_translated_texts = translate(fr_texts, en_model, en_tokenizer,
language=source_lang)
return back_translated_texts
# en_texts = ['This is so cool', 'I hated the food', 'They were very helpful']
# # en_texts = 'I love you.'
# aug_texts = back_translate(en_texts, target_model, target_tokenizer, en_model, en_tokenizer, source_lang="en", target_lang="fr")
# print(aug_texts)

42
ssll/tools/divide_data.py Normal file
View File

@ -0,0 +1,42 @@
import json, os, argparse
from settings import parse_args
import random
parser = argparse.ArgumentParser()
parser.add_argument('--train_dir',type=str, help='Divide part of labeled data to unlabeled.')
parser.add_argument('--data_outdir',type=str, help='Directory of divided labeled and unlabeled data.')
parser.add_argument('--num_label',type=int,help='Number of labeled data for each training dataset.')
args = parser.parse_args()
data_outdir = args.data_outdir
train_dir = args.train_dir
K = args.num_label
train_file = os.path.join(train_dir,'ripe_data','train.json')
train_out_label = os.path.join(data_outdir,'label_train.json')
train_out_unlabel = os.path.join(data_outdir,'unlabel_train.json')
with open(train_file,'r', encoding='utf-8') as f:
data = [json.loads(i) for i in f.readlines()]
label_idx_set = set(random.sample(range(len(data)), K))
label_data = [data[i] for i in range(len(data)) if i in label_idx_set]
print(K, len(label_data))
unlabel_data = [data[i] for i in range(len(data)) if i not in label_idx_set]
with open(train_out_label, 'w', encoding='utf-8') as f:
for i in label_data:
print(json.dumps(i, ensure_ascii=False), file=f)
with open(train_out_unlabel, 'w', encoding='utf-8') as f:
for i in unlabel_data:
print(json.dumps(i, ensure_ascii=False), file=f)
str_list = ['100','500','2000']
for i in str_list:
random.shuffle(unlabel_data)
with open(os.path.join(data_outdir, 'unlabel_'+i+'_train.json'), 'w', encoding='utf-8') as f:
for j in unlabel_data[:int(i)]:
print(json.dumps(j, ensure_ascii=False), file=f)
print('Finished dividing', args.train_dir)
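# Example invocation (hypothetical paths; the script expects
# $train_dir/ripe_data/train.json to exist):
#   python tools/divide_data.py --train_dir DATA/ag \
#     --data_outdir DATA_Divided/DATA_100/ag --num_label 100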

244
ssll/tools/eda.py Normal file
View File

@ -0,0 +1,244 @@
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou
import nltk
# nltk.download('wordnet', '/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
# nltk.download('omw-1.4')
# from nltk_data.corpora import wordnet
from nltk.corpus import wordnet  # required by get_synonyms below
nltk.data.path.append('/Users/yingxiu/Desktop/Xiu_CODE/Xiu/SCL/nltk_data')
import re
import random
from random import shuffle
random.seed(1)
#stop words list
stop_words = ['i', 'me', 'my', 'myself', 'we', 'our',
'ours', 'ourselves', 'you', 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his',
'himself', 'she', 'her', 'hers', 'herself',
'it', 'its', 'itself', 'they', 'them', 'their',
'theirs', 'themselves', 'what', 'which', 'who',
'whom', 'this', 'that', 'these', 'those', 'am',
'is', 'are', 'was', 'were', 'be', 'been', 'being',
'have', 'has', 'had', 'having', 'do', 'does', 'did',
'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
'because', 'as', 'until', 'while', 'of', 'at',
'by', 'for', 'with', 'about', 'against', 'between',
'into', 'through', 'during', 'before', 'after',
'above', 'below', 'to', 'from', 'up', 'down', 'in',
'out', 'on', 'off', 'over', 'under', 'again',
'further', 'then', 'once', 'here', 'there', 'when',
'where', 'why', 'how', 'all', 'any', 'both', 'each',
'few', 'more', 'most', 'other', 'some', 'such', 'no',
'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too',
'very', 's', 't', 'can', 'will', 'just', 'don',
'should', 'now', '']
#cleaning up text
def get_only_chars(line):
clean_line = ""
line = line.replace("", "")
line = line.replace("'", "")
line = line.replace("-", " ") # replace hyphens with spaces
line = line.replace("\t", " ")
line = line.replace("\n", " ")
line = line.lower()
for char in line:
if char in 'qwertyuiopasdfghjklzxcvbnm ':
clean_line += char
else:
clean_line += ' '
clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces
if clean_line and clean_line[0] == ' ':
clean_line = clean_line[1:]
return clean_line
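# e.g. get_only_chars("Well-done, Chef!") -> "well done chef "  (note the trailing space)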
########################################################################
# *Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################
#for the first time you use wordnet
# import nltk
# nltk.download('wordnet')
def synonym_replacement(words, n):
new_words = words.copy()
random_word_list = list(
set([word for word in words if word not in stop_words]))
random.shuffle(random_word_list)
num_replaced = 0
for random_word in random_word_list:
synonyms = get_synonyms(random_word)
if len(synonyms) >= 1:
synonym = random.choice(list(synonyms))
new_words = [synonym if word == random_word else word for word in new_words]
#print("replaced", random_word, "with", synonym)
num_replaced += 1
if num_replaced >= n: # only replace up to n words
break
#this is stupid but we need it, trust me
sentence = ' '.join(new_words)
new_words = sentence.split(' ')
return new_words
def get_synonyms(word):
synonyms = set()
for syn in wordnet.synsets(word):
for l in syn.lemmas():
synonym = l.name().replace("_", " ").replace("-", " ").lower()
synonym = "".join(
[char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm'])
synonyms.add(synonym)
if word in synonyms:
synonyms.remove(word)
return list(synonyms)
########################################################################
# *Random deletion
# Randomly delete words from the sentence with probability p
########################################################################
def random_deletion(words, p):
#obviously, if there's only one word, don't delete it
if len(words) == 1:
return words
#randomly delete words with probability p
new_words = []
for word in words:
r = random.uniform(0, 1)
if r > p:
new_words.append(word)
#if you end up deleting all words, just return a random word
if len(new_words) == 0 and len(words)>1:
rand_int = random.randint(0, len(words)-1)
return [words[rand_int]]
return new_words
########################################################################
# *Random swap
# Randomly swap two words in the sentence n times
########################################################################
def random_swap(words, n):
new_words = words.copy()
for _ in range(n):
if len(new_words) > 1:
new_words = swap_word(new_words)
else:
new_words = words
return new_words
def swap_word(new_words):
random_idx_1 = random.randint(0, len(new_words)-1)
random_idx_2 = random_idx_1
counter = 0
while random_idx_2 == random_idx_1:
random_idx_2 = random.randint(0, len(new_words)-1)
counter += 1
if counter > 3:
return new_words
new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
return new_words
########################################################################
# *Random insertion
# Randomly insert n words into the sentence
########################################################################
def random_insertion(words, n):
new_words = words.copy()
for _ in range(n):
add_word(new_words)
return new_words
def add_word(new_words):
synonyms = []
counter = 0
while len(synonyms) < 1:
if len(new_words)>1:
random_word = new_words[random.randint(0, len(new_words)-1)]
synonyms = get_synonyms(random_word)
counter += 1
if counter >= 10:
return
random_synonym = synonyms[0]
random_idx = random.randint(0, len(new_words)-1)
new_words.insert(random_idx, random_synonym)
########################################################################
# main data augmentation function
########################################################################
def eda(sentence, alpha_sr=0.0, alpha_ri=0.0, alpha_rs=0.1, p_rd=0.05, num_aug=5):
sentence = get_only_chars(sentence)
words = sentence.split(' ')
words = [word for word in words if word!='']
num_words = len(words)
augmented_sentences = []
num_new_per_technique = int(num_aug/4)+1
#sr
if (alpha_sr > 0):
n_sr = max(1, int(alpha_sr*num_words))
for _ in range(num_new_per_technique):
a_words = synonym_replacement(words, n_sr)
augmented_sentences.append(' '.join(a_words))
#ri
if (alpha_ri > 0):
n_ri = max(1, int(alpha_ri*num_words))
for _ in range(num_new_per_technique):
a_words = random_insertion(words, n_ri)
augmented_sentences.append(' '.join(a_words))
#rs
if (alpha_rs > 0):
n_rs = max(1, int(alpha_rs*num_words))
for _ in range(num_new_per_technique):
a_words = random_swap(words, n_rs)
augmented_sentences.append(' '.join(a_words))
#rd
if (p_rd > 0):
for _ in range(num_new_per_technique):
a_words = random_deletion(words, p_rd)
augmented_sentences.append(' '.join(a_words))
augmented_sentences = [get_only_chars(sentence)
for sentence in augmented_sentences]
shuffle(augmented_sentences)
#trim so that we have the desired number of augmented sentences
if num_aug >= 1:
augmented_sentences = augmented_sentences[:num_aug]
else:
keep_prob = num_aug / len(augmented_sentences)
augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]
#append the original sentence
augmented_sentences.append(sentence)
shuffle(augmented_sentences)
return augmented_sentences
if __name__ == '__main__':
s = 'I like the cat with white and black fur.'
print(eda(s))

View File

@ -0,0 +1,47 @@
import os
import json
import sys
ori_dir = 'lamol_data'
task_names = ['ag','dbpedia','yelp','yahoo']
for task in task_names:
dir_path = os.path.join('data_lamol','TC',task)
if not os.path.isdir(dir_path):
os.makedirs(dir_path, exist_ok=True)
# Convert train-data
train_file = os.path.join(ori_dir,task+'_to_squad-train-v2.0.json')
with open(train_file,'r', encoding='utf-8') as f:
data = json.load(f)['data']
train_outfile = os.path.join(dir_path,'train.json')
with open(train_outfile,'w', encoding='utf-8') as fw:
for sample in data:
sample_dict = {}
x = sample['paragraphs'][0]['context'] + ' ' + sample['paragraphs'][0]['qas'][0]['question']
y = sample['paragraphs'][0]['qas'][0]['answers'][0]['text']
sample_dict['input'] = x
sample_dict['output'] = y
sample_dict['question'] = sample['paragraphs'][0]['qas'][0]['question']
sample_dict['text_wo_question'] = sample['paragraphs'][0]['context']
print(json.dumps(sample_dict, ensure_ascii=False), file=fw)
# Convert test-data
test_file = os.path.join(ori_dir,task+'_to_squad-test-v2.0.json')
with open(test_file,'r', encoding='utf-8') as f:
data = json.load(f)['data']
test_outfile = os.path.join(dir_path,'test.json')
with open(test_outfile,'w', encoding='utf-8') as fw:
for sample in data:
sample_dict = {}
x = sample['paragraphs'][0]['context'] + ' ' + sample['paragraphs'][0]['qas'][0]['question']
y = sample['paragraphs'][0]['qas'][0]['answers'][0]['text']
sample_dict['input'] = x
sample_dict['output'] = y
sample_dict['question'] = sample['paragraphs'][0]['qas'][0]['question']
sample_dict['text_wo_question'] = sample['paragraphs'][0]['context']
print(json.dumps(sample_dict, ensure_ascii=False), file=fw)
print('Finished processing ', task, flush=True)

652
ssll/unifymodel/dataset.py Normal file
View File

@ -0,0 +1,652 @@
import torch
import csv
import os
import re
import json
import numpy as np
from settings import parse_args
from eda import *
from pretrain import *
from torch.utils.data import DataLoader
max_input_length_dict = {
'woz.en': 128,
'sst': 128,
'srl': 128,
'wikisql': 300,
'squad':512,
'ag':128,
'yelp':256,
'yahoo':128,
'dbpedia':128,
'amazon':200
}
max_ans_length_dict = {
'woz.en': 20,
'sst': 10,
'srl': 100,
'wikisql': 100,
'squad': 50,
'ag':10,
'yelp':10,
'yahoo':10,
'dbpedia':10,
'amazon':10
}
args = parse_args()
class LBDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, max_input_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
self.max_ans_len = args.max_ans_len
self.special_token_ids = special_token_ids
self.task_name = task_name
self.max_input_len = max_input_length_dict[self.task_name]
with open(data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
if task_name in ['wikisql','woz.en'] and 'test' in self.data_path:
with open(os.path.join(args.data_dir, task_name, 'answers.json'),'r') as f:
self.answers = json.load(f)
print(f'Number of answers is {len(self.answers)}')
if task_name in ['wikisql', 'squad'] and 'test' in self.data_path:
self.targets = []
for i in ori_data:
self.targets.append(i['output'])
def get_answer(self, target): # get answer text from intent
# replace - and _ to make the text seems more natural
if type(target) == list:
target = target[0]
if self.task_name == 'wikisql':
ans = target
else:
ans = target.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
if self.task_name in ['wikisql','woz.en','squad','srl']:
text = example['input']
question = example['question']
text_wo_question = example['text_wo_question']
else:
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
question = example['question'].replace('-',' ').replace('_',' ').replace('\\',' ')
text_wo_question = example['text_wo_question'].replace('-',' ').replace('_',' ').replace('\\',' ')
ans = self.get_answer(example['output'])
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
ans_list = example['output']
return [(text, ans, question, text_wo_question, example['id'], ans_list)]
return [(text, ans, question, text_wo_question)]
def per_tokenization(self, task_name, d):
# print(d,flush=True)
# input_text, ans_text = d
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
input_text, ans_text, question, text_wo_question, idx, ans_list = d
else:
input_text, ans_text, question, text_wo_question = d
raw_text = text_wo_question
input_sen = task_name+':' + raw_text + '<QUES>' + question
context_sen = input_sen + '<ANS>' # for evaluation
ans_sen = ans_text
# all_sen = context_sen + ans_sen
all_sen = raw_text + '<QUES>'
res_dict = {
'all_sen': all_sen,
'input_sen': input_sen,
'context_sen': context_sen,
'ans_sen': ans_sen,
'question': question,
'raw_text': raw_text
}
if self.task_name in ['woz.en','squad','wikisql'] and 'test' in self.data_path:
res_dict['id'] = idx
res_dict['ans_list'] = ans_list
return res_dict
def data_tokenization(self, task_name, data):
# print('data',data[:10],flush=True)
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
def sort_by_index(self):
self.data.sort(key=lambda x: x['id'])
def get_indices(self):
return [d['id'] for d in self.data]
class ULBDataset(torch.utils.data.Dataset):
def __init__(self, task_name, tokz, data_path, max_input_len=100, special_token_ids=None):
self.tokz = tokz
self.data_path = data_path
self.max_ans_len = args.max_ans_len
self.special_token_ids = special_token_ids
self.task_name = task_name
self.max_input_len = max_input_length_dict[self.task_name]
with open(data_path, "r", encoding='utf-8') as f:
# data = [json.loads(i) for i in f.readlines()]
# data = [self.parse_example(i) for i in data] # [utter, label]
ori_data = [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
if task_name in ['wikisql','woz.en'] and 'test' in self.data_path:
with open(os.path.join(args.data_dir, task_name, 'answers.json'),'r') as f:
self.answers = json.load(f)
def get_answer(self, target): # get answer text from intent
if type(target) == list:
target = target[0]
if self.task_name == 'wikisql':
ans = target
else:
ans = target.replace("-", " ").replace("_", " ")
return ans
def parse_example(self, example):
if self.task_name in ['wikisql','woz.en','squad','srl']:
text = example['input']
question = example['question']
text_wo_question = example['text_wo_question']
else:
text = example['input'].replace('-',' ').replace('_',' ').replace('\\',' ')
question = example['question'].replace('-',' ').replace('_',' ').replace('\\',' ')
text_wo_question = example['text_wo_question'].replace('-',' ').replace('_',' ').replace('\\',' ')
ans = self.get_answer(example['output'])
return [(text, ans, question, text_wo_question)]
def per_tokenization(self, task_name, d):
input_text, ans_text, question, text_wo_question = d
raw_text = text_wo_question
input_sen = task_name+':' + raw_text + '<QUES>' + question
context_sen = input_sen + '<ANS>' # for evaluation
ans_sen = ans_text
# all_sen = context_sen + ans_sen
all_sen = raw_text + '<QUES>'
return {
'all_sen': all_sen,
'input_sen': input_sen,
'context_sen': context_sen,
'ans_sen': ans_sen,
'question': question,
'raw_text': raw_text
}
def data_tokenization(self, task_name, data):
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
class MemoryDataset(ULBDataset):
def __init__(self, res):
self.max_ans_len = args.max_ans_len
self.max_input_len = args.max_input_len
self.data = res['batch_text'] #list of dict
class PinnedBatch:
def __init__(self, data):
self.data = data
def __getitem__(self, k):
return self.data[k]
def __setitem__(self, key, value):
self.data[key] = value
def pin_memory(self):
for k in self.data.keys():
if type(self.data[k])!=list:
self.data[k] = self.data[k].pin_memory()
return self
def __repr__(self):
return self.data.__repr__()
def __str__(self):
return self.data.__str__()
def items(self):
return self.data.items()
def keys(self):
return self.data.keys()
def pad_seq(seq, pad, max_len, pad_left=False):
if pad_left:
return [pad] * (max_len - len(seq)) + seq
else:
return seq + [pad] * (max_len - len(seq))
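# e.g. pad_seq([5, 6], pad=0, max_len=4)                -> [5, 6, 0, 0]
#      pad_seq([5, 6], pad=0, max_len=4, pad_left=True) -> [0, 0, 5, 6]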
class PadBatchSeq:
def __init__(self, pad_id=0, tokz=None, task=None):
self.pad_id = pad_id
self.tokz = tokz
self.task = task
def __call__(self, batch):
if self.task is not None:
max_input_len = max_input_length_dict[self.task]
max_ans_len = max_ans_length_dict[self.task]
else:
max_input_len = 600
max_ans_len = args.max_ans_len
input_sen = [i['input_sen'] for i in batch]
context_sen = [i['context_sen'] for i in batch]
all_sen = [i['all_sen'] for i in batch]
if 'ans_sen' in batch[0]:
ans_sen = [i['ans_sen'] for i in batch]
ans_encoding = self.tokz(ans_sen, padding='longest', max_length=max_ans_len, truncation=True, return_tensors='pt')
raw_sen = [i['raw_text'] for i in batch]
ques_sen = [i['question'] for i in batch]
input_encoding = self.tokz(input_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
all_encoding = self.tokz(all_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
context_encoding = self.tokz(context_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
raw_text_encoding = self.tokz(raw_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
question_encoding = self.tokz(ques_sen, padding='longest', max_length=max_input_len, truncation=True, return_tensors='pt')
res = {}
res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
if 'ans_sen' in batch[0]:
res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
res['batch_text'] = batch
return PinnedBatch(res)
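# Usage sketch (illustrative, not from this repo): PadBatchSeq is intended as
# a DataLoader collate_fn, tokenizing and padding each field to the longest
# sequence in the batch, e.g.
#   loader = DataLoader(train_dataset, batch_size=16,
#                       collate_fn=PadBatchSeq(pad_id=tokz.pad_token_id, tokz=tokz, task='sst'))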
# * Information for different tasks.
if args.data_type == 'intent':
TASK2INFO = {
"banking": {
"dataset_class": LBDataset,
"dataset_folder": "banking",
"task_type": "CLS",
},
"clinc": {
"dataset_class": LBDataset,
"dataset_folder": "clinc",
"task_type": "CLS",
},
"hwu": {
"dataset_class": LBDataset,
"dataset_folder": "hwu",
"task_type": "CLS",
},
"atis": {
"dataset_class": LBDataset,
"dataset_folder": "atis",
"task_type": "CLS",
},
"tod": {
"dataset_class": LBDataset,
"dataset_folder": "FB_TOD_SF",
"task_type": "CLS",
},
"snips": {
"dataset_class": LBDataset,
"dataset_folder": "snips",
"task_type": "CLS",
},
"top_split1": {
"dataset_class": LBDataset,
"dataset_folder": "top_split1",
"task_type": "CLS",
},
"top_split2": {
"dataset_class": LBDataset,
"dataset_folder": "top_split2",
"task_type": "CLS",
},
"top_split3": {
"dataset_class": LBDataset,
"dataset_folder": "top_split3",
"task_type": "CLS",
}
}
elif args.data_type =='slot':
TASK2INFO = {
"dstc8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "dstc8",
"task_type": "SlotTagging",
},
"restaurant8": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "restaurant8",
"task_type": "SlotTagging",
},
"atis": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "atis",
"task_type": "SlotTagging",
},
"mit_movie_eng": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_eng",
"task_type": "SlotTagging",
},
"mit_movie_trivia": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_movie_trivia",
"task_type": "SlotTagging",
},
"mit_restaurant": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "mit_restaurant",
"task_type": "SlotTagging",
},
"snips": {
"dataset_class": PromptSlotTaggingDataset,
"dataset_folder": "snips",
"task_type": "SlotTagging",
}
}
elif args.data_type =='decanlp' or args.data_type == 'tc' or args.data_type == 'mix':
TASK2INFO = {
"ag": {
"dataset_class": LBDataset,
"dataset_folder": "ag",
"task_type": "tc",
},
'dbpedia':{
'dataset_class':LBDataset,
'dataset_folder':'dbpedia',
'task_type':'tc',
},
'yelp':{
'dataset_class':LBDataset,
'dataset_folder':'yelp',
'task_type':'tc',
},
'yahoo':{
'dataset_class':LBDataset,
'dataset_folder':'yahoo',
'task_type':'tc',
},
'snips':{
'dataset_class':LBDataset,
'dataset_folder':'snips',
'task_type':'tc',
},
'amazon':{
'dataset_class':LBDataset,
'dataset_folder':'amazon',
'task_type':'tc',
},
'woz.en':{
'dataset_class':LBDataset,
'dataset_folder':'woz.en',
'task_type':'decanlp',
},
'squad':{
'dataset_class':LBDataset,
'dataset_folder':'squad',
'task_type':'decanlp',
},
'sst':{
'dataset_class':LBDataset,
'dataset_folder':'sst',
'task_type':'decanlp',
},
'srl':{
'dataset_class':LBDataset,
'dataset_folder':'srl',
'task_type':'decanlp',
},
'wikisql':{
'dataset_class':LBDataset,
'dataset_folder':'wikisql',
'task_type':'decanlp',
}
}
def get_unlabel_data(path, task):
info = TASK2INFO[task]
# data_path = os.path.join(path, info['dataset_folder'],'unlabel_train.json')
if args.unlabel_ratio == -1:
num_unlabel = 2000
else:
num_unlabel = args.num_label*args.unlabel_ratio
data_path = os.path.join(path, info['dataset_folder'],str(num_unlabel)+'_unlabel_train.json')
return data_path
def get_unlabel_dict(path, task, tokz, args):
info = TASK2INFO[task]
special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
if args.data_type == 'intent':
return LBDataset(task, tokz, path, args.max_input_len, special_token_ids=special_token_ids)
else:
return None
def write_mix_train_file(label_train_dataset, unlabel_train_dataset, out_file, oridir):
datatype_list=['label_train','unlabel_train']
with open(out_file,'w') as fw:
for datatype in datatype_list:
datapath = os.path.join(oridir,datatype+'.json')
with open(datapath,'r') as f:
data = [json.loads(i) for i in f.readlines()]
for row in data:
# print(json.dumps(row['input'].strip('"'), ensure_ascii=False),file=fw)
print(row['input'].strip('"'), file=fw)
def create_dataloader_for_pretrain(mix_train_file, tokz, model, args):
data_files = {}
data_files['train'] = mix_train_file
extension = mix_train_file.split(".")[-1]
if extension == 'txt': extension = 'text'
datasets = load_dataset(extension, data_files=data_files)
train_dataset = datasets['train']
max_seq_length=128
def tokenize_function(examples):
return tokz(examples[text_column_name], return_attention_mask=False)
column_names = train_dataset.column_names
text_column_name = "text" if "text" in column_names else column_names[0]
tokenized_datasets = train_dataset.map(
tokenize_function,
batched=True,
remove_columns=column_names,
load_from_cache_file=True,
)
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
inputs_length=max_seq_length,
noise_density=0.15,
mean_noise_span_length=3.0,
)
data_collator = DataCollatorForT5MLM(
tokenizer=tokz,
noise_density=0.15,
mean_noise_span_length=3.0,
input_length=max_seq_length,
target_length=targets_length,
pad_token_id=model.config.pad_token_id,
decoder_start_token_id=model.config.decoder_start_token_id,
)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {
k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= expanded_inputs_length:
total_length = (
total_length // expanded_inputs_length) * expanded_inputs_length
# Split by chunks of max_len.
result = {
k: [t[i: i + expanded_inputs_length]
for i in range(0, total_length, expanded_inputs_length)]
for k, t in concatenated_examples.items()
}
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
train_dataset = tokenized_datasets.map(
group_texts,
batched=True,
load_from_cache_file=True,
)
train_sampler = torch.utils.data.RandomSampler(train_dataset)
mix_dataloader = DataLoader(train_dataset, batch_size=16,
sampler=train_sampler, num_workers=args.num_workers, pin_memory=True, collate_fn=data_collator)
return mix_dataloader
class MixedDataset(LBDataset):
def __init__(self, unlabel_data, label_data, tokz, max_input_len=100):
self.max_input_len = max_input_len
self.tokz = tokz
self.max_ans_len = args.max_ans_len
self.special_token_ids = {'ans_token':tokz.convert_tokens_to_ids('ANS:')}
self.data = []
self.data += unlabel_data
self.data += label_data
class AugDataset(LBDataset):
def __init__(self, task_name, tokz, label_data_path, unlabel_data_path, max_input_len=100, special_token_ids=None):
self.tokz = tokz
self.max_ans_len = args.max_ans_len
self.special_token_ids = special_token_ids
self.max_input_len = max_input_length_dict[task_name]
with open(label_data_path, "r", encoding='utf-8') as f:
ori_data = [json.loads(i) for i in f.readlines()]
with open(unlabel_data_path, "r", encoding='utf-8') as f:
ori_data += [json.loads(i) for i in f.readlines()]
data = []
for i in ori_data:
data += self.parse_example(i)
self.data = []
if len(data) > 0:
self.data = self.data_tokenization(task_name, data)
def parse_example(self, example):
text = example['userInput']['text']
aug_text_list = eda(text, num_aug=args.num_aug)
ans = self.get_answer(example['intent'])
res = [(aug_text, ans) for aug_text in aug_text_list]
res.append((text, ans))
return res
def get_datasets(path, tasks, tokz, max_input_len=100, special_token_ids=None):
res = {}
for task in tasks:
res[task] = {}
info = TASK2INFO[task]
if args.unlabel_ratio == -1:
num_unlabel = 2000
else:
num_unlabel = args.num_label * args.unlabel_ratio
max_input_len = max_input_length_dict[task]
# ! Remember to restore this.
res[task]['label_train'] = LBDataset(
task, tokz, os.path.join(path, info['dataset_folder'], 'label_train.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
if args.use_unlabel:
res[task]['unlabel_train'] = ULBDataset(
task, tokz, os.path.join(path, info['dataset_folder'], str(num_unlabel)+'_unlabel_train.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
res[task]['val'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
res[task]['test'] = TASK2INFO[task]['dataset_class'](task, tokz, os.path.join(path, info['dataset_folder'], 'test.json'), max_input_len=max_input_len, special_token_ids=special_token_ids)
return res
class PseudoDataset(torch.utils.data.Dataset):
def __init__(self, task, task_data, tokz, max_input_len=100, curr_data=None):
'''
task2data : {'task_name': [data1, data2, data3, ...]}
'''
self.max_input_len = max_input_length_dict[task]
self.tokz = tokz
self.max_ans_len = args.max_ans_len
self.data = []
# for task in task2data:
# self.data += self.data_tokenization(task, task2data[task])
self.data += self.data_tokenization(task, task_data)
if curr_data is not None:
self.data += curr_data
def per_tokenization(self, task_name, d):
# print(d,flush=True)
# input_text, ans_text = d
raw_text, ans_text, question = d
input_sen = task_name+':' + raw_text + '<QUES>' + question
context_sen = input_sen + '<ANS>' # for evaluation
ans_sen = ans_text
# all_sen = context_sen + ans_sen
all_sen = raw_text + '<QUES>'
return {
'all_sen': all_sen,
'input_sen': input_sen,
'context_sen': context_sen,
'ans_sen': ans_sen,
'question': question,
'raw_text': raw_text
}
def data_tokenization(self, task_name, data):
data = [self.per_tokenization(task_name, i) for i in data]
return data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]

472
ssll/unifymodel/generate.py Normal file
View File

@ -0,0 +1,472 @@
from transformers import T5Tokenizer, T5Config
import logging
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
import json
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import torch.distributed as dist
import os
import time
import gc
import pickle
import argparse
import math
import re
import torch.nn as nn
import torch.utils.data as data
import torch.multiprocessing as mp
import copy
import torch.nn.functional as F
from unifymodel.dataset import pad_seq  # padding_convert below relies on pad_seq (import path assumed)
t5tokz = T5Tokenizer.from_pretrained('t5-base')
# ! Get answers given inputs.
def get_answer(tokz, model, example, example_mask, max_ans_len, args=None, is_eval=False):
temperature = args.temperature
model.eval()
device = 'cuda'
eos_token = tokz.eos_token_id
pad_token = tokz.pad_token_id
bad_words = t5tokz.additional_special_tokens
bad_words_ids = tokz(bad_words, add_special_tokens=False).input_ids
# bad_words_ids = [[badid] for badid in tokz.additional_special_tokens_ids]
# * Get sequence outputs.
# print('Begin generating answers.',flush=True)
outputs = model.generate(input_ids=example, attention_mask=example_mask, do_sample=False, eos_token_id=eos_token, pad_token_id=pad_token,
output_scores=True, return_dict_in_generate=True, early_stopping=True, max_length=max_ans_len, bad_words_ids=bad_words_ids)
# print(outputs,flush=True)
output_seq = outputs.sequences
output_scores = outputs.scores
if not is_eval:
output_scores = torch.stack(list(output_scores), dim=0)
output_scores = output_scores.permute(1, 0, 2) # [80,10,50258]
# print(f'output_scores {output_scores.shape}',flush=True)
# print(f'output_seq {output_seq.shape}',flush=True)
# print('output_seq',output_seq.shape, 'output_scores',len(output_scores), output_scores[0].shape, flush=True)
# generate one: example [16, 29], out_seq [16, 39], out_score[0] [16, 50258], out_score [max_length=input_ids.shape[-1], batch_size, vocab_size] tuple
# generate many: out_seq [16, 5, 29], out_scores [16, 5, 10, 50258]
# return output_seq[:,example.shape[1]:], output_scores
return output_seq, output_scores
def textid_decode(text, eos, tokz, contain_eos=False):
if eos in text:
if contain_eos:
idx = text.index(eos)
text = text[:idx]
else:
text = text[:text.index(eos)]
text_id = text
len_text = len(text)
text = tokz.decode(text).strip()
return text, len_text, text_id
def padding_convert(text_list, eos):
tt_list = []
for text in text_list:
if eos in text:
eos_indexs = [i for i, x in enumerate(text) if x == eos]
# print('eos index', eos_indexs,flush=True)
if len(eos_indexs) > 1:
text = text[:eos_indexs[1]]
tt_list.append(text)
tt_lens = [len(i) for i in tt_list]
tt_lens = torch.tensor(tt_lens, dtype=torch.long).to('cuda')
tt_pad = torch.tensor([pad_seq(i, eos, max(tt_lens), pad_left=True)
for i in tt_list], dtype=torch.long).to('cuda')
return tt_pad, tt_lens
def gen_pseudo_data(model, task, tokz, max_output_len=90, batch_size=4, target_count=100, output_file=None, args=None, question=None):
device = 'cuda'
task_prefix = task+':'
task_prefix_id = tokz.encode(task_prefix)[0]
# task_prefix_id = tokz.encode('Task:')[0]
input_tokens = torch.full((batch_size,1), task_prefix_id).cuda()
adapter_name = task.replace('.','_')
model.set_active_adapters(adapter_name)
model = model.cuda()
pseudo_list = []
utter_set = set()
eos_token = tokz.eos_token_id
pad_token = tokz.pad_token_id
# bad_words = tokz.additional_special_tokens
# bad_words_ids = [[badid] for badid in tokz.additional_special_tokens_ids]
bad_words = t5tokz.additional_special_tokens
bad_words_ids = tokz(bad_words, add_special_tokens=False).input_ids
if output_file is None:
raise ValueError("Pseudo output file is not specified.")
if output_file is not None:
if not os.path.isdir(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
# * Generate pseudo samples.
input_set = set()
while len(pseudo_list) < target_count:
once_target_count = min(32, target_count)
with torch.no_grad():
model.eval()
outputs = model.generate(input_ids=input_tokens, do_sample=True, eos_token_id=eos_token, pad_token_id=pad_token,
output_scores=True, return_dict_in_generate=True, early_stopping=True, max_length=min(max_output_len, 256),
num_return_sequences=once_target_count, bad_words_ids=bad_words_ids)
output_seq = outputs.sequences
output_list = output_seq.tolist()
# print('output_list len',len(output_list), flush=True)
# print('batch size',batch_size, flush=True)
for i in range(len(output_list)): # num_return_sequences outputs are produced per prompt row.
# print("Test what's the first token:",output_list[i],flush=True)
if len(output_list[i])<=1: continue
output_id = output_list[i][1:]
# print(output_id, flush=True)
if eos_token in output_id:
output_id = output_id[:output_id.index(eos_token)]
output = tokz.decode(output_id, skip_special_tokens=True) #! Not contain '<QUES>' token
# print(output, flush=True)
if len(output)>=2:
input_sen = output.strip()
context_sen = task_prefix + input_sen + '<QUES>' + question + '<ANS>'
input_ids = torch.tensor([tokz.encode(context_sen)]).cuda()
ans_id = model.generate(input_ids=input_ids, do_sample=False, eos_token_id=eos_token, pad_token_id=pad_token, bad_words_ids=bad_words_ids)
output_sen = tokz.decode(ans_id[0], skip_special_tokens=True)
else:
input_sen, output_sen = None, None
if input_sen and output_sen:
print('INPUT::',input_sen, '===> OUTPUT::', output_sen, flush=True)
if input_sen not in input_set:
input_set.add(input_sen) # avoid duplicate inputs.
pseudo_list.append([input_sen, output_sen, question])
pseudo_list = pseudo_list[:target_count]
with open(output_file, 'w', encoding='utf8') as f:
for input, output, _ in pseudo_list:
print(json.dumps({'input': input, 'output': output}, ensure_ascii=False), file=f)
model.train()
return pseudo_list
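# A minimal sketch of the loop above (names as used in trainer.py): the model is
# prompted with the task prefix, samples candidate inputs, then greedily decodes an
# answer for each input under the "<task>: <input> <QUES> <question> <ANS>" format:
#   pseudo = gen_pseudo_data(prev_ema_model, 'ag', tokz, target_count=100,
#                            output_file='out/pseudo_ag.json', args=args,
#                            question=question_dict['ag'])
# Duplicate inputs are skipped, and sampling repeats until target_count pairs exist.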
def get_all_priors(model, tokz, args):
def pseudo_prompt(task):
return f"In the \"{task}\" task, which intent category best describes: \""
all_prior_info = {}
for task in args.tasks:
prompt_id = [tokz.bos_token_id]+tokz.encode(pseudo_prompt(task))+[tokz.eos_token_id]
prompt_id = torch.LongTensor(prompt_id).to('cuda')
prior_out = model.encoder(input_ids=prompt_id)
prior_emb, _ = model.avg_attn(prior_out[0])
prior_mean, prior_logvar = model.prior_mean(prior_emb), model.prior_logvar(prior_emb)
all_prior_info[task]=(prior_mean, prior_logvar)
return all_prior_info
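# For reference, kl_divergence (defined in the shared utils; assumed here to be the
# closed-form KL between diagonal Gaussians) computes
#   KL(q || p) = 0.5 * sum( p_logvar - q_logvar
#                           + (exp(q_logvar) + (q_mu - p_mu)**2) / exp(p_logvar) - 1 ),
# so get_nearest_task below returns the task whose prior is closest to the sample's posterior.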
def get_nearest_task(model, tokz, sample, all_prior_info, args):
def pseudo_prompt(task):
return f"In the \"{task}\" task, which intent category best describes: \""
all_posteriors={}
batch_size = len(sample['utter_id'])
for task in args.tasks:
prompt_id = [tokz.bos_token_id]+tokz.encode(pseudo_prompt(task))
bt_prompt_id = torch.LongTensor([prompt_id for _ in range(batch_size)]).to('cuda')
bt_px_id = torch.cat((bt_prompt_id,sample['utter_id'].to('cuda')),dim=1)
bt_px_id = bt_px_id.to('cuda')
if len(bt_px_id)!=batch_size:
raise ValueError('Tensor concatenate is wrong.')
# px_id = tokz.encode(prompt_id+sample['utter_id'].tolist())
# prompt_id = torch.LongTensor(prompt_id).to('cuda')
# px_id = torch.LongTensor(px_id).to('cuda')
post_out = model.encoder(input_ids=bt_px_id)
post_emb, _ = model.avg_attn(post_out[0])
post_mean, post_logvar = model.post_mean(post_emb), model.post_logvar(post_emb)
all_posteriors[task]=(post_mean, post_logvar)
min_kl = 1e10
res_task = args.tasks[0]
all_kl_dist = []
for task in all_prior_info.keys():
prior_mean, prior_logvar = all_prior_info[task]
post_mean, post_logvar = all_posteriors[task]
kl_dist = kl_divergence(post_mean, post_logvar, prior_mean, prior_logvar)
all_kl_dist.append(kl_dist)
if kl_dist < min_kl:
min_kl = kl_dist
res_task = task
# print(all_kl_dist,flush=True)
return res_task
def get_pred_context(tokz, pred_task_name, gt_task_name, sample):
new_list = []
for ss in sample['context_id'].tolist():
# print('1',len(ss),flush=True)
context = tokz.decode(ss)
# print('1',context,flush=True)
# new_context = text.replace(f'In the "{gt_task_name}" task, which intent category best describes: "', pred_task_name).strip()
new_context = re.sub(gt_task_name,pred_task_name,context)
# print('2',new_context,flush=True)
new_context_id = tokz.encode(new_context)
# print('2',len(new_context_id),flush=True)
new_list.append(new_context_id)
# if len(new_context_id)!=len(new_list[0]):
# raise ValueError('Lengths are not match in this batch.')
context_lens = [len(i) for i in new_list]
context_mask = torch.ByteTensor([[1] * context_lens[i] + [0] * (max(context_lens)-context_lens[i]) for i in range(len(context_lens))])
new_res = torch.tensor([pad_seq(i, tokz.eos_token_id, max(context_lens), pad_left=True) for i in new_list], dtype=torch.long).to('cuda')
new_lens = torch.tensor(context_lens,dtype=torch.long).to('cuda')
# new_res = torch.LongTensor(new_list).to('cuda',non_blocking=True)
return new_res, new_lens
def get_pseudo_labels_for_all(tokz, model, dataloader, task, sampling=False, args=None):
out_dir = os.path.join(args.output_dir, f'{task}_unlabel_pseudo.json')
out_dir_data = []
data_path = get_unlabel_data(args.data_dir, task)
with open(data_path, 'r', encoding='utf-8') as f:
all_data = [json.loads(i) for i in f.readlines()]
with torch.no_grad():
model.eval()
pred_ans_all = []
input_all = []
for i, data in enumerate(dataloader):
sample = data['context_id'].to('cuda')
sample_lens = data['context_lens'].to('cuda')
pred_ans, _ = get_answer(tokz, model, sample, sample_lens, max_ans_len=args.max_ans_len, args=args) # get_answer returns (sequences, scores); keep only the sequences here.
pred_ans_all.append(pred_ans)
pred_ans_all = communicate_tensor(pred_ans_all, pad_token=tokz.eos_token_id).tolist()
for i in range(len(pred_ans_all)):
res = {}
res['userInput'] = {}
res['userInput']['text'] = all_data[i]['userInput']['text']
res['intent'] = tokz.decode(cut_eos(pred_ans_all[i], tokz.eos_token_id))
out_dir_data.append(res)
with open(out_dir, 'w', encoding='utf-8') as f:
for res in out_dir_data:
print(json.dumps(res), file=f)
model.train()
return None
def get_pseudo_labels_per_batch(tokz, model, data, sampling=False, args=None):
with torch.no_grad():
model.eval()
context = data['context_id'].to('cuda')
context_lens = data['context_lens'].to('cuda')
out_seqs, out_scores = get_answer(tokz, model, context, context_lens, max_ans_len=args.max_ans_len, args=args)
# print('out_seqs',out_seqs,'out_scores',out_scores,flush=True)
# Get all lens.
eos_token = tokz.eos_token_id
ans_lens = []
max_seq_len = out_seqs.size(-1)
btz = out_seqs.size(0)
out_seqs_list = out_seqs.tolist()
# print(out_seqs_list,flush=True)
all_ans_id = []
for i in out_seqs_list:
# print(i,flush=True)
_, seq_len, seq_id = textid_decode(i, eos_token, tokz, contain_eos=True)
ans_lens.append(seq_len) # Answer length; the EOS itself was stripped by textid_decode.
all_ans_id.append(seq_id)
# print('ans len', ans_lens, flush=True)
# print('max_seq_len',max_seq_len, 'max(all_lens)',max(ans_lens), flush=True)
# ans_scores = [[out_scores[context_lens[i]:all_lens[i]]] for i in range(btz)]
# * Compute PPL.
_, ppl_batch = compute_ppl(out_seqs, out_scores, tokz) # compute_ppl returns (answer ids, per-sample PPL).
# * Prepare batch for training.
spe_token_id = tokz.convert_tokens_to_ids('<ANS>') # Use the same answer separator token as the dataset.
batch_list = []
for i in range(btz):
info = {}
input_id = data['input_id'][i][:data['input_lens'][i]]
input_id = input_id.tolist()
info['input_id'] = input_id
info['ans_id'] = all_ans_id[i] + [tokz.eos_token_id]
info['context_id'] = input_id + [spe_token_id]
info['all_id'] = info['context_id'] + all_ans_id[i] + [tokz.eos_token_id]
batch_list.append(info)
# Print some pseudo samples.
if i==0:
print('-'*10, 'Unlabeled sample with pseudo labels:', flush=True)
print('input:', tokz.decode(info['input_id']), flush=True)
print('all:',tokz.decode(info['all_id']), flush=True)
print('-'*20,flush=True)
# print(batch_list[0],flush=True)
padbatch = PadBatchSeq()
batch = padbatch(batch_list)
print('ppl',ppl_batch.tolist(),flush=True)
return ppl_batch, batch
def compute_ppl(out_seqs, out_scores, tokz):
# * Get answer lens.
ans_lens = []
max_seq_len = out_seqs.size(-1)
out_seqs_list = out_seqs.tolist()
all_ans_id = []
for i in out_seqs_list:
_, seq_len, seq_id = textid_decode(i, tokz.eos_token_id, tokz, contain_eos=True)
ans_lens.append(seq_len) # Answer length; the EOS itself was stripped by textid_decode.
all_ans_id.append(seq_id)
btz = out_seqs.size(0)
ans_scores = out_scores # may get -inf
# print(f'answer scores max {torch.max(ans_scores)} min {torch.min(ans_scores)}')
# for j in range(btz):
# score_per_sample = []
# for l in range(max_seq_len):
# score_per_sample.append(out_scores[l][j].tolist())
# ans_scores.append(score_per_sample)
# ans_scores = torch.tensor(ans_scores)
# ans_scores = out_scores.permute(1, 0, 2)
log_softmax = nn.LogSoftmax(dim=-1) # for vocab_size
log_soft = log_softmax(ans_scores)
select_log_soft = []
log_soft_flatten = log_soft.view(-1, ans_scores.shape[-1]).to('cuda')
# log_soft_flatten = log_soft.view(-1, len(tokz)).to('cuda')
out_seq_flatten = out_seqs.contiguous().view(-1).to('cuda')
select_log_soft = log_soft_flatten[torch.arange(log_soft_flatten.shape[0]),out_seq_flatten]
select_log_soft = select_log_soft.view(btz, max_seq_len)
# ans_scores:[16, 10, 50258] log_soft [16, 10, 50258] select_log_soft [16, 10]
ans_score_mask = [[1]* ans_lens[i]+[0]*(max_seq_len-ans_lens[i]) for i in range(btz)]
ans_score_mask = torch.tensor(ans_score_mask).to('cuda')
neg_loglh = torch.sum(-select_log_soft * ans_score_mask, dim=1) / torch.sum(ans_score_mask, dim=1)
ppl_batch = torch.exp(neg_loglh).to('cuda')
return all_ans_id, ppl_batch
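# Usage sketch: with out_seqs of shape [B, L] and stacked out_scores of shape
# [B, L, vocab], the PPL is exp of the mean negative log-likelihood over answer
# positions only (tokens past the first EOS are masked), so a confidence filter is e.g.
#   ans_ids, ppl = compute_ppl(out_seqs, out_scores, tokz)
#   keep_mask = torch.le(ppl, args.pseudo_tau) # keep high-confidence pseudo labels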
def old_compute_ppl(out_seqs, out_scores, tokz):
# * Get answer lens.
ans_lens = []
max_seq_len = out_seqs.size(-1)
out_seqs_list = out_seqs.tolist()
all_ans_id = []
for i in out_seqs_list:
_, seq_len, seq_id = textid_decode(i, tokz.eos_token_id, tokz, contain_eos=True)
ans_lens.append(seq_len) # Answer length; the EOS itself was stripped by textid_decode.
all_ans_id.append(seq_id)
btz = out_seqs.size(0)
ans_scores = out_scores
# for j in range(btz):
# score_per_sample = []
# for l in range(max_seq_len):
# score_per_sample.append(out_scores[l][j].tolist())
# ans_scores.append(score_per_sample)
# ans_scores = torch.tensor(ans_scores)
# ans_scores = out_scores.permute(1, 0, 2)
log_softmax = nn.LogSoftmax(dim=-1) # for vocab_size
log_soft = log_softmax(ans_scores)
select_log_soft = []
log_soft_flatten = log_soft.view(-1, len(tokz)).to('cuda')
out_seq_flatten = out_seqs.contiguous().view(-1).to('cuda')
select_log_soft = log_soft_flatten[torch.arange(log_soft_flatten.shape[0]),out_seq_flatten]
select_log_soft = select_log_soft.view(btz, max_seq_len)
# ans_scores:[16, 10, 50258] log_soft [16, 10, 50258] select_log_soft [16, 10]
ans_score_mask = [[1]* ans_lens[i]+[0]*(max_seq_len-ans_lens[i]) for i in range(btz)]
ans_score_mask = torch.tensor(ans_score_mask).to('cuda')
neg_loglh = torch.sum(-select_log_soft * ans_score_mask, dim=1) / torch.sum(ans_score_mask, dim=1)
ppl_batch = torch.exp(neg_loglh).to('cuda')
return all_ans_id, ppl_batch
def map_to_same_length(seq_list, tokz):
max_seq_len = 10
eos_token = tokz.eos_token_id
new_seq_list = []
for seq in seq_list:
if len(seq) < max_seq_len:
new_seq_list.append(seq + [eos_token] * (max_seq_len - len(seq)))
else:
new_seq_list.append(seq[:max_seq_len]) # Truncate instead of silently dropping long sequences.
return new_seq_list
def get_pseudo_labels_per_batch_with_sampling(tokz, model, data, args=None):
eos_token = tokz.eos_token_id
with torch.no_grad():
model.eval()
context = data['context_id'].to('cuda')
context_lens = data['context_lens'].to('cuda')
out_seqs, out_scores = get_answer(tokz, model, context, context_lens, max_ans_len=args.max_ans_len, args=args)
# generate many: out_seq [80, 29], out_scores [80, 10, 50258]
btz = args.train_batch_size
all_ans_id = []
all_ppl = []
ans_id, ppl = compute_ppl(out_seqs, out_scores, tokz)
# output_seq = output_seq.split(args.pl_sample_num)
# output_seq = torch.stack(output_seq)
# output_scores = output_scores.split(args.pl_sample_num)
# output_scores = torch.stack(output_scores)
# out_seq [16, 5, 29], out_scores [16, 5, 10, 50258]
ppl = torch.stack(torch.split(ppl, args.pl_sample_num)) # [16, 5]
print('ppl shape after split', ppl.shape, flush=True)
for i in range(btz):
example_ppl = ppl[i,:]
# scores shape torch.Size([5, 10, 50258]) seq shape torch.Size([5, 10])
min_idx = torch.argmin(example_ppl)
all_ans_id.append(ans_id[min_idx+i*args.pl_sample_num])
all_ppl.append(example_ppl[min_idx])
all_ppl = torch.tensor(all_ppl)
# * Prepare batch for training.
spe_token_id = tokz.convert_tokens_to_ids('<ANS>') # Same answer separator token as the dataset.
batch_list = []
for i in range(btz):
info = {}
input_id = data['input_id'][i][:data['input_lens'][i]]
input_id = input_id.tolist()
info['input_id'] = input_id
info['ans_id'] = all_ans_id[i] + [tokz.eos_token_id]
info['context_id'] = input_id + [spe_token_id]
info['all_id'] = info['context_id'] + all_ans_id[i] + [tokz.eos_token_id]
batch_list.append(info)
# Print some pseudo samples.
if i==0:
print('-'*10, 'Unlabeled sample with pseudo labels:', flush=True)
print('input:', tokz.decode(info['input_id']), flush=True)
print('all:',tokz.decode(info['all_id']), flush=True)
print('-'*20,flush=True)
# print(batch_list[0],flush=True)
padbatch = PadBatchSeq()
batch = padbatch(batch_list)
print('ppl', all_ppl.tolist(),flush=True)
return all_ppl, batch
if __name__ =='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--tasks',nargs='+')
parser.add_argument('--model_aug',default=False, type=bool)
parser.add_argument('--freeze_plm', default=False, type=bool)
args = parser.parse_args()
model_path = 'outputs/tc_semi_meantc_K100_ag17/ema_model_last.pt'
tokz_path = 'outputs/tc_semi_meantc_K100_ag17/model_cache/out'
t5adapter = 'outputs/tc_semi_meantc_K100_ag17/t5adapter'
tokz = T5Tokenizer.from_pretrained(tokz_path)
config = T5Config()
model = SSLLModel(config)
model = model.initialize(t5adapter, 'ag', args)
model = model.fresh('ag', args, ema=False)
model.load_state_dict(torch.load(model_path))
output_file = 'outputs/tc_semi_meantc_K100_ag17/pseudo_ag.json'
_ = gen_pseudo_data(model, 'ag', tokz, max_output_len=90, batch_size=30, target_count=10, output_file=output_file, args=None)

ssll/unifymodel/memory.py Normal file
@@ -0,0 +1,252 @@
import torch
import random
import numpy as np
from eda import *
import torch.nn as nn
from unifymodel.dataset import *
class Curr_Unlabel_Memory(object):
def __init__(self,keys=None,values=None, questions=None,args=None):
if keys is None:
self.keys = []
self.values = []
self.questions = []
else:
self.keys = keys
self.values = values
self.questions = questions
self.total_keys = len(self.keys)
if args.unlabel_ratio == -1: # fix unlabel to 2k
num_unlabel = 2000
else:
num_unlabel = args.num_label * args.unlabel_ratio
self.memory_size = int(args.newmm_size * num_unlabel)
def push(self, keys, values, questions, args=None):
# print(f'Lengths: key {len(keys)}, values: {len(values)}, questions: {len(questions)}', flush=True)
for key, value, ques in zip(keys, values, questions):
key = key.tolist() # Keys arrive as tensors; store plain lists so the membership check below works.
if self.total_keys < self.memory_size:
assert type(self.keys) == list
# print(f'key {key}, keys {self.keys}', flush=True)
if key not in self.keys:
self.keys.append(key)
self.values.append(value)
self.questions.append(ques)
else:
# Memory is full: evict a random slot before inserting.
drop_idx = random.choice(range(self.memory_size))
self.keys.pop(drop_idx)
self.values.pop(drop_idx)
self.questions.pop(drop_idx)
self.keys.append(key)
self.values.append(value)
self.questions.append(ques)
self.total_keys = len(self.keys) # Update the memory size.
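# A minimal usage sketch: keys are sentence embeddings, values the raw unlabeled
# sentences, and questions the task question strings, e.g.
#   mem = Curr_Unlabel_Memory(args=args)
#   keys, values, questions = get_sentence_embedding(model, batch, args=args)
#   mem.push(keys, values, questions)
# The memory caps at args.newmm_size * num_unlabel entries and then evicts at random.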
def create_batch_to_augment_memory(old_task, old_memory, curr_memory, tokz=None, args=None):
querys = torch.tensor(old_memory['keys'])
questions = old_memory['questions']
neighbors = get_neighbors(querys, task_name=None, K=args.back_kneighbors, memory=curr_memory, args=args, questions=questions)
aug_unlabel_batch = create_batch_from_memory(neighbors, tokz, args=args, task_name=old_task)
return aug_unlabel_batch
def cosine_similarity(v1, m2):
# Compare one query vector against either a single vector or a matrix of keys.
if len(m2.shape) == 1 and len(v1.shape) == 1:
cos = nn.CosineSimilarity(dim=0)
elif len(m2.shape) > 1:
v1 = v1.unsqueeze(0)
cos = nn.CosineSimilarity(dim=1)
else:
raise ValueError(f'Incompatible shapes: v1 {v1.shape}, m2 {m2.shape}')
score = cos(v1, m2)
return score
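# Example (hypothetical tensors): cosine_similarity(torch.randn(768), torch.randn(100, 768))
# broadcasts the single query against all 100 keys and returns a [100] score vector,
# which get_neighbors below ranks with torch.topk.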
def get_sentence_embedding(model, batch, args=None):
model.set_active_adapters(None)
batch_size = batch['raw_id'].size()[0]
input_tokens = batch['raw_id'].cuda()
attn_masks = batch['raw_mask'].cuda()
outputs = model.transformer.encoder(input_ids=input_tokens, attention_mask=attn_masks, return_dict=True)
pooled_emb = outputs.last_hidden_state
sen_emb = torch.mean(pooled_emb, dim=1) # (batch_size, 768)
# print('sentence emb shape', sen_emb.shape, flush=True)
# if args.diff_question:
all_questions = []
all_values = []
all_keys = []
for i in range(batch_size):
sample_ques = batch['batch_text'][i]['question']
sample_value = batch['batch_text'][i]['raw_text']
all_questions.append(sample_ques)
all_values.append(sample_value)
# all_keys.append(sen_emb[i].tolist())
all_keys.append(sen_emb[i,:])
# return sen_emb, all_values, all_questions
return all_keys, all_values, all_questions
def construct_memory(task_name, model, batch, memory=None, args=None):
sen_embs, all_values, all_questions = get_sentence_embedding(model, batch, args=args)
if memory is None:
memory = {}
if task_name not in memory:
memory[task_name] = {}
memory[task_name]['keys'] = []
memory[task_name]['values'] = []
memory[task_name]['questions'] = []
batch_size = len(batch['batch_text'])
for i in range(batch_size):
key = sen_embs[i].tolist()
# value = batch['raw_id'][i] #
value = batch['batch_text'][i]['raw_text'] # * Store the raw sentence as the value into the memory.
memory[task_name]['keys'].append(key)
memory[task_name]['values'].append(value)
memory[task_name]['questions'].append(all_questions[i])
return memory
def get_neighbors(querys, task_name=None, K=1, memory=None, args=None, questions=None):
assert memory is not None
if task_name is not None: #* forward augment, memory is old unlabel memory
keys = memory[task_name]['keys']
values = memory[task_name]['values']
# questions = memory[task_name]['questions'] # forward augment needs current unlabel questions
else: #* backward augment, memory is current_unlabel_memory
keys = memory.keys
values = memory.values
# print(f'Number of the memory content is {len(keys)}',flush=True)
assert questions is not None # should be old memory question
# if args.diff_question and questions is None:
# questions = memory.questions
# questions = memory[task_name]['questions']
keys_tensor = torch.tensor(keys).cuda()
retrieved_samples = []
for q in range(len(querys)):
query = querys[q]
query = query.cuda()
# print('query shape', query.shape, 'key shape', keys_tensor.shape, flush=True)
# similarity_scores = torch.matmul(keys_tensor, query.T)
similarity_scores = cosine_similarity(query, keys_tensor)
top_k = torch.topk(similarity_scores, K)
top_k_idxs = top_k.indices.tolist()
if top_k.values[-1] <= args.similarity_tau:
continue
# print('Top K largest consine similarity scores:', top_k.values.tolist(), flush=True)
K_neighbor_keys = [keys[i] for i in top_k_idxs]
neighbors = [values[i] for i in top_k_idxs]
# retrieved_samples.append(neighbors)
retrieved_samples.append((neighbors, questions[q])) # neighbors is a list of unlabeled inputs, questions[q] is a sentence.
return retrieved_samples
def create_batch_from_memory(samples, tokz, args, task_name):
input_sen, context_sen, all_sen = [], [], []
task_prefix = task_name+':'
raw_sen, ques_sen = [], []
if args.diff_question:
for neighbors, question in samples:
for sen in neighbors:
# * Augments retrieved neighbors
if args.aug_neighbors_from_memory:
aug_text_list = eda(sen, num_aug=args.num_aug)
for aug_text in aug_text_list:
input_text = task_prefix + aug_text + '<QUES>' + question
context_sen.append(input_text+'<ANS>')
all_sen.append(aug_text + '<QUES>')
raw_sen.append(aug_text)
ques_sen.append(question)
# all_sen.append(aug_text+'<ANS>'+ans_text)
# ans_sen.append(ans_text)
# * Directly use retrieved neighbors
else:
input_text = task_prefix + sen + '<QUES>' + question
context_text = input_text+'<ANS>'
all_text = sen + '<QUES>'
input_sen.append(input_text)
context_sen.append(context_text)
all_sen.append(all_text)
raw_sen.append(sen)
ques_sen.append(question)
else:
# get_neighbors returns (neighbors, question) pairs in both settings, so unpack the
# shared question here as well.
for neighbors, question in samples:
for sen in neighbors:
# * Augments retrieved neighbors
if args.aug_neighbors_from_memory:
aug_text_list = eda(sen, num_aug=args.num_aug)
for aug_text in aug_text_list:
input_text = task_prefix + aug_text + '<QUES>' + question
context_sen.append(input_text+'<ANS>')
all_sen.append(aug_text + '<QUES>')
raw_sen.append(aug_text)
ques_sen.append(question)
# all_sen.append(aug_text+'<ANS>'+ans_text)
# ans_sen.append(ans_text)
# * Directly use retrieved neighbors
else:
input_text = task_prefix + sen + '<QUES>' + question
context_text = input_text+'<ANS>'
all_text = sen + '<QUES>'
input_sen.append(input_text)
context_sen.append(context_text)
all_sen.append(all_text)
raw_sen.append(sen)
ques_sen.append(question)
batch = []
batch_list = []
if len(input_sen)>=1:
for l in range(len(input_sen)):
sample_dict = {}
sample_dict['all_sen'] = all_sen[l]
sample_dict['input_sen'] = input_sen[l]
sample_dict['context_sen'] = context_sen[l]
sample_dict['question'] = ques_sen[l]
sample_dict['raw_text'] = raw_sen[l]
batch.append(sample_dict)
# input_encoding = tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
# all_encoding = tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
# context_encoding = tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
# raw_text_encoding = tokz(raw_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
# question_encoding = tokz(ques_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
res = {}
# res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
# res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
# res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
# res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
# res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
res['batch_text'] = batch
else:
res = None
return res
def get_old_center_dict(old_memory, prev_tasks, args=None):
center_dict = {}
for prev_task in prev_tasks:
old_keys = old_memory[prev_task]['keys'] # list of list
old_keys_tensor = torch.tensor(old_keys).cuda()
old_center = torch.mean(old_keys_tensor,dim=0)
center_dict[prev_task] = old_center
return center_dict
def adapter_selection(prev_tasks, curr_center, old_center_dict, args=None):
distance_list = []
for prev_task in prev_tasks:
dist = cosine_similarity(curr_center, old_center_dict[prev_task])
distance_list.append(dist)
print(f'Distances among centers {distance_list}',flush=True)
max_dist = max(distance_list)
max_idx = distance_list.index(max_dist)
adapter_name = prev_tasks[max_idx]
return adapter_name
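# Usage sketch (mirroring Trainer._train in trainer.py): embed a few batches of the
# new task, average them into curr_center, and initialize from the closest old task:
#   old_centers = get_old_center_dict(memory, prev_tasks, args=args)
#   init_name = adapter_selection(prev_tasks, curr_center, old_centers, args=args)
# The selected task's adapter is then loaded as the starting point for the new adapter.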

ssll/unifymodel/model.py Normal file
@@ -0,0 +1,41 @@
import torch
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration, PreTrainedModel, T5Model
from transformers.adapters import T5AdapterModel
from torch import nn
class SSLLModel(T5AdapterModel):
# def __init__(self, config, args=None):
# super(SSLLModel, self).__init__(config)
# self.args = args
# self.model = T5AdapterModel(config)
# self.model.transformer is T5Model
def initialize(self, model_path, task_name, args=None):
# from_pretrained is a classmethod, so rebind self to the freshly loaded backbone.
self = self.from_pretrained('t5-base', cache_dir=model_path)
self.config.vocab_size = self.config.vocab_size + len(args.tasks) + 2 # Add task-specific special tokens. 2 for <ANS> and <QUES>
self.transformer.resize_token_embeddings(self.config.vocab_size)
# print('old config',self.config,flush=True)
# print('decoder_start_token_id',self.config.decoder_start_token_id, flush=True)
if args.model_aug:
self.config.dropout_rate = args.dropout_rate
return self
def fresh(self, task_name, args=None, ema=False, initial=True):
# * Add task-specific adapters.
# self.add_adapter('adapter'+task_name)
# self.add_seq2seq_lm_head('lm_head'+task_name)
# self.train_adapter('adapter'+task_name)
task_name = task_name.replace('.','_')
if initial:
self.add_adapter(task_name)
self.add_seq2seq_lm_head(task_name)
self.set_active_adapters(task_name)
self.train_adapter(task_name) # Freeze other model params.
if not args.freeze_plm:
self.freeze_model(False) # Finetuning all parameters. Default is True.
# print('new config',self.config,flush=True)
# print('decoder_start_token_id',self.config.decoder_start_token_id, flush=True)
if ema:
for param in self.parameters():
param.detach_()
return self
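# Typical lifecycle (as used in trainer.py and the __main__ demo in generate.py;
# paths are placeholders):
#   model = SSLLModel(T5Config()).initialize(t5adapter_dir, task, args)
#   model = model.fresh(task, args, ema=False, initial=True)        # trainable student
#   ema_model = SSLLModel(T5Config()).initialize(t5adapter_dir, task, args)
#   ema_model = ema_model.fresh(task, args, ema=True, initial=True) # detached EMA teacher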

ssll/unifymodel/trainer.py Normal file
@@ -0,0 +1,940 @@
from unifymodel.utils import *
from unifymodel.generate import *
from unifymodel.model import SSLLModel
from unifymodel.dataset import PadBatchSeq, TASK2INFO, LBDataset, get_datasets, get_unlabel_data, get_unlabel_dict, MixedDataset
from unifymodel.dataset import *
from unifymodel.memory import *
from transformers import T5Config, T5Tokenizer, T5Model,T5ForConditionalGeneration, T5AdapterModel
from torch.utils.data import DataLoader
import torch.distributed as dist
import os, time, gc, json, pickle, argparse, math, threading, shutil
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.nn import DataParallel
import torch.multiprocessing as mp
import numpy as np
import transformers
from transformers import get_linear_schedule_with_warmup, Conv1D, AdamW
from torch.utils.tensorboard import SummaryWriter
import importlib
import copy
from collections import Counter
from rouge import Rouge
from metrics import *
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
BatchEncoding,
AutoModelForMaskedLM,
FlaxT5ForConditionalGeneration,
HfArgumentParser,
PreTrainedTokenizerBase,
T5Config,
is_tensorboard_available,
set_seed,
T5ForConditionalGeneration,
Trainer,
TrainingArguments,
is_torch_tpu_available,
)
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from pretrain import *
def compute_solver_loss(model, loss_fn, total, args=None, kl_loss=None):
input_tokens, input_mask, target_tokens, target_mask = total
model = model.cuda()
model.train()
outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
loss = outputs[0]
return loss
def compute_lm_loss(model, loss_fn, total, args=None, kl_loss=None):
input_tokens, decoder_input_ids, labels = total
model = model.cuda().train()
decoder_input_ids = None # Let T5 build decoder inputs from the labels via the standard right-shift.
outputs = model(input_ids=input_tokens, decoder_input_ids=decoder_input_ids, labels=labels)
loss = outputs[0]
return loss
def compute_feedback_loss(model, ema_model, loss_fn, total, args=None, kl_loss=None):
input_tokens, input_mask, target_tokens, target_mask = total
model, ema_model = model.cuda(), ema_model.cuda()
model.train()
ema_model.train()
# outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
# logits = outputs.logits.contiguous()
# logits = logits.view(-1, logits.size(-1))
ema_outputs = ema_model(input_ids=input_tokens, attention_mask=input_mask, labels=target_tokens)
# ema_logits = ema_outputs.logits.contiguous().detach()
# ema_logits = ema_logits.view(-1, ema_logits.size(-1))
# target_mask = (target_mask>0)
# feedback_loss = kl_loss(ema_logits, logits, label_mask=target_mask) # Student model give feedback to teacher model.
feedback_loss = ema_outputs.loss
print('feedback_loss',feedback_loss.item(), flush=True)
return feedback_loss
def compute_consistency_loss(model, ema_model, loss_fn, total, args=None, kl_loss=None, tokz=None, epoch=0, task_name=None):
input_tokens, input_mask, target_tokens, target_mask = total
model, ema_model = model.cuda(), ema_model.cuda()
model.train()
ema_model.train()
btz = input_tokens.size()[0]
# print(f'length of tokz {len(tokz)}', flush=True)
assert task_name is not None
max_ans_len = max_ans_length_dict[task_name]
# * Teacher prediction
ema_seqs, ema_scores = get_answer(tokz, ema_model, input_tokens, input_mask, max_ans_len, args=args, is_eval=False)
# print('ema_seqs shape',ema_seqs.shape,'input shape',input_tokens.shape,flush=True)
# ema_seqs [192, 20], target [192, 4], input [192, 119]
# print(f'length of tokz {len(tokz)}', flush=True)
ema_target_tokens = ema_seqs[:,1:]
# print(f'ema tgt token 0 {ema_target_tokens[0]}',flush=True)
ema_target_tokens[ema_target_tokens==tokz.pad_token_id] = -100
# print(f'ema tgt token 1 {ema_target_tokens[0]}',flush=True)
ema_target_tokens = ema_target_tokens.contiguous()
ema_target_mask = torch.ones_like(ema_target_tokens)
ema_target_mask[ema_target_tokens==-100] = 0
ema_target_mask = (ema_target_mask > 0)
if args.add_confidence_selection:
_, ppl = compute_ppl(ema_seqs[:,1:], ema_scores, tokz)
ppl = torch.stack(torch.split(ppl, args.pl_sample_num)) # [16, 5]
all_ppl = []
for i in range(btz):
example_ppl = ppl[i,:]
# scores shape torch.Size([5, 10, 50258]) seq shape torch.Size([5, 10])
min_idx = torch.argmin(example_ppl)
all_ppl.append(example_ppl[min_idx])
ppl_batch = torch.tensor(all_ppl)
# print(f'ppl {ppl_batch}', flush=True)
# * Select high confidence outputs.
low_ppl_mask = torch.le(ppl_batch, args.pseudo_tau).to('cuda') # 1 = confident sample (low PPL).
# print(f'confidence mask {low_ppl_mask.shape}',flush=True) # [48] num_aug * batch_size
# print(f'ema tgt tokens {ema_target_tokens.shape}',flush=True)
# Zero out low-confidence rows; the zeros equal T5's pad id and become -100 below, so the loss ignores them.
ema_target_tokens = (ema_target_tokens.T * low_ppl_mask).T
ema_target_tokens[ema_target_tokens==tokz.pad_token_id] = -100
# * Student prediction
outputs = model(input_ids=input_tokens, attention_mask=input_mask, labels=ema_target_tokens)
logits = outputs.logits.contiguous()
logits = logits.view(-1, logits.size(-1))
consist_loss = outputs.loss
consistency_weight = get_current_consistency_weight(args, epoch)
if args.rdrop:
outputs2 = model(input_ids=input_tokens, attention_mask=input_mask, labels=ema_target_tokens)
logits2 = outputs2.logits.contiguous()
logits2 = logits2.view(-1, logits2.size(-1))
rdrop_loss = 0.5 * kl_loss(logits, logits2, label_mask=ema_target_mask) + 0.5 * kl_loss(logits2, logits, label_mask=ema_target_mask)
else:
rdrop_loss = 0.0
consistency_loss = consistency_weight * (consist_loss + rdrop_loss)
print('consistency_kl_loss:',consistency_loss.item(),flush=True)
# print(f'rdrop_loss: {rdrop_loss.item()}', flush=True)
return consistency_loss
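# In brief, the mean-teacher step above: the EMA teacher decodes pseudo answers for
# the unlabeled batch; with add_confidence_selection, answers whose PPL exceeds
# args.pseudo_tau are blanked out; the student is then trained toward the surviving
# teacher targets, optionally plus a symmetric R-Drop KL between two stochastic
# student passes, all scaled by the ramped consistency weight.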
def get_model_inputs(data, global_step=10, tokz=None, model_type='cls', task=None, data_type='label'):
if model_type=='cls':
input_tokens = data['context_id']
input_mask = data['context_mask'].contiguous()
if global_step < 3:
print('-'*10, 'Solver Inputs and targets', '-'*10, flush=True)
print('input:', tokz.decode(input_tokens[0]), flush=True)
print('input id:', input_tokens[0], flush=True)
print('attention_mask:', input_mask[0], flush=True)
if data_type == 'label':
target_tokens = data['ans_id']
target_mask = data['ans_mask'].contiguous()
if global_step < 3:
print('target:',tokz.decode(target_tokens[0]),flush=True)
print('target id:', target_tokens[0],flush=True)
print('\n',flush=True)
target_tokens[target_tokens==tokz.pad_token_id] = -100
target_tokens = target_tokens.contiguous()
input_total = (input_tokens.cuda(), input_mask.cuda(), target_tokens.cuda(), target_mask.cuda())
else:
input_total = (input_tokens.cuda(), input_mask.cuda(), None, None)
return input_total
elif model_type=='lm':
batch_size = data['input_id'].shape[0]
assert task is not None
task_prefix_id = tokz.encode(task+':')[0]
# print(task_prefix_id, flush=True)
input_tokens = torch.full((batch_size,1), task_prefix_id)
decoder_input_ids = data['all_id'][:,:-1].contiguous()
# labels = data['all_id'][:,1:].clone().detach()
labels = data['all_id'][:,:].clone().detach()
if global_step < 3:
print('-'*10,'LM Inputs and targets','-'*10, flush=True)
print('input:',tokz.decode(input_tokens[0]),flush=True)
print('input id:',input_tokens[0],flush=True)
print('target:',tokz.decode(labels[0]),flush=True)
print('target id:', labels[0],flush=True)
print('\n',flush=True)
labels[labels==tokz.pad_token_id] = -100
input_total = (input_tokens.cuda(), decoder_input_ids.cuda(), labels.cuda())
return input_total
elif model_type=='pretrain':
input_ids = data['input_ids'].cuda()
labels = data['labels'].cuda()
if global_step <= 3:
print('pretrain_input: ', tokz.decode(input_ids[0]), flush=True)
print('pretrain_label: ', tokz.decode(labels[0]), flush=True)
labels[labels==tokz.pad_token_id] = -100
labels = labels.contiguous()
decoder_input_ids = data['decoder_input_ids'].cuda()
return input_ids, labels, decoder_input_ids
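# The three layouts built above, in brief (token layouts follow unifymodel/dataset.py):
#   'cls'      -> (context ids, context mask, answer ids with pads set to -100, answer mask)
#   'lm'       -> (task-prefix token, decoder inputs all_id[:, :-1], labels all_id)
#   'pretrain' -> (corrupted input ids, span labels with pads set to -100, decoder input ids)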
def update_ema_variables(model, ema_model, alpha, global_step):
# Use the true average until the exponential average is more correct
# alpha = ema_decay
alpha = min(1 - 1 / (global_step + 1), alpha)
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data = ema_param.data.cuda()
param.data = param.data.cuda()
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
# ema_param.data.mul_(alpha).add_(1 - alpha, param.data)
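# EMA schedule sketch: alpha ramps from 0 toward the configured decay, so early steps
# copy the student almost directly and later steps average slowly. With ema_decay=0.999:
#   step 0   -> alpha = 0       (teacher := student)
#   step 9   -> alpha = 0.9
#   step 999 -> alpha = 0.999   (and stays there)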
class Trainer:
def __init__(self, args, tokz, datasets, logger, cache_dir, memory):
# Other
self.args = args
self.datasets = datasets
self.logger = logger
self.rank = self.args.local_rank
self.device = self.args.device
self.global_step = 0
self.task_step = 0
distributed = False if self.rank == -1 else True
self.task_dict = {k: v for v, k in enumerate(self.args.tasks)}
# Tensorboard log
if self.rank in [0, -1]: # Only write logs on master node
self.train_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'train'), flush_secs=10)
self.valid_writer = SummaryWriter(os.path.join(args.tb_log_dir, 'valid'))
self.config = T5Config()
# print('T5 Config:',self.config,flush=True)
self.tokz = tokz
print('len tokz of t5', len(self.tokz), flush=True)
self.cache_dir = cache_dir
self.save_folder = self.args.output_dir
os.makedirs(self.save_folder, exist_ok=True)
# randomness
np.random.seed(args.seed)
prng = np.random.RandomState()
torch.random.manual_seed(args.seed)
gpu = not self.args.no_gpu
if gpu: torch.cuda.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)
self.loss_fn = nn.CrossEntropyLoss(reduction='none').to(self.device, non_blocking=True)
self.kl_loss = KDLoss(KD_term=self.args.KD_term, T=self.args.KD_temperature)
# * Load datasets.
self.data_loaders = {}
for task in self.datasets:
self.data_loaders[task] = {}
train_sampler = torch.utils.data.RandomSampler(datasets[task]['label_train'])
valid_sampler = None
self.data_loaders[task]['label_train'] = DataLoader(datasets[task]['label_train'], batch_size=self.args.train_batch_size,
sampler=train_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=task))
self.data_loaders[task]['val'] = DataLoader(datasets[task]['val'], batch_size=self.args.eval_batch_size,
sampler=valid_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=task))
self.data_loaders[task]['test'] = DataLoader(datasets[task]['test'], batch_size=self.args.eval_batch_size,
sampler=None, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))
def _train(self, prev_model, prev_ema_model, curr_task, prev_tasks, all_tasks_res, question_dict):
if len(prev_tasks)!=0:
last_adapter_name = prev_tasks[-1].replace('.','_')
curr_adapter_name = curr_task.replace('.','_')
if prev_model is None:
assert len(prev_tasks) == 0
t5dir = os.path.join(self.args.output_dir,'t5basedir')
t5adapter = os.path.join(self.args.output_dir,'t5adapter')
model = SSLLModel(self.config)
model = model.initialize(t5adapter, curr_task, self.args)
model = model.fresh(curr_adapter_name, self.args, ema=False, initial=True)
model.train()
if self.args.meantc:
ema_model = SSLLModel(self.config)
ema_model = ema_model.initialize(t5adapter, curr_task, self.args)
ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=True)
ema_model = ema_model.cuda()
ema_model.train()
ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
assert ema_num_tunable_params == 0 # The EMA teacher must expose no trainable parameters.
# print('Total number of tunable parameters is %d' % num_tunable_params, flush=True)
self.logger.info("Successfully initialize the model with pretrained T5 model!")
self.logger.info(f'Tokenizer additional specitial token ids {self.tokz.additional_special_tokens_ids}')
else:
if self.args.random_initialization:
model = prev_model.fresh(curr_adapter_name, self.args, ema=False, initial=True)
model.train()
if self.args.meantc:
ema_model = prev_ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=True)
ema_model = ema_model.cuda()
ema_model.train()
ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
assert ema_num_tunable_params == 0
else:
model = prev_model
curr_adapter_name = curr_task.replace('.','_')
model.load_state_dict(prev_ema_model.state_dict())
model.load_adapter(os.path.join(self.save_folder,last_adapter_name+'_adapter'), load_as=curr_adapter_name)
model = model.fresh(curr_adapter_name, self.args, ema=False, initial=False)
model.train()
# ema_model = SSLLModel(self.config)
ema_model = copy.deepcopy(prev_ema_model)
ema_model.load_state_dict(prev_ema_model.state_dict())
ema_model.load_adapter(os.path.join(self.save_folder,last_adapter_name+'_adapter'), load_as=curr_adapter_name)
ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=False)
ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
assert ema_num_tunable_params == 0
# ema_model.train()
if self.args.backward_augment:
curr_unlabel_memory = Curr_Unlabel_Memory(args=self.args)
model = model.to(self.device, non_blocking=True)
print('model.config',model.config,flush=True)
# ! Generate pseudo data from the previously learned model.
if self.args.gen_replay and len(prev_tasks)>=1:
prev_train_dict = {}
if self.args.use_unlabel:
pseudo_data_count = int((len(self.datasets[curr_task]['label_train']) + len(self.datasets[curr_task]['unlabel_train'])) * self.args.pseudo_data_ratio) // len(prev_tasks)
else:
pseudo_data_count = int(len(self.datasets[curr_task]['label_train']) * self.args.pseudo_data_ratio) // len(prev_tasks)
self.logger.info(f'Will generate {pseudo_data_count} pseudo samples of previous tasks.')
inferred_data = {}
if self.args.construct_memory:
memory = {}
else: memory = None
for prev_task in prev_tasks:
pseudo_output_file = os.path.join(self.args.output_dir, curr_task+'_pseudo_'+prev_task+'.json')
self.logger.info('Generated pseudo data will be written into '+str(pseudo_output_file))
prev_data = gen_pseudo_data(prev_ema_model, prev_task, tokz=self.tokz, max_output_len=max_input_length_dict[prev_task],
batch_size=4, target_count=pseudo_data_count,
output_file=pseudo_output_file, args=self.args, question=question_dict[prev_task])
inferred_data[prev_task] = prev_data
self.logger.info(f'Finished generating pseudo data for {len(prev_tasks)} previous task(s).')
self.logger.info(f'Inferring pseudo data from {prev_task}')
for i in range(0, min(6, pseudo_data_count)):
self.logger.info(f'PSEUDO: {prev_data[i]}')
prev_dataset = PseudoDataset(prev_task, prev_data, tokz=self.tokz, max_input_len=self.args.max_input_len)
prev_dataloader = DataLoader(prev_dataset, batch_size=self.args.train_batch_size, sampler=torch.utils.data.RandomSampler(prev_dataset),
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=prev_task))
prev_train_dict[prev_task] = prev_dataloader
# * Construct memory for pseudo data.
if self.args.construct_memory:
for idx, prev_batch in enumerate(prev_dataloader):
memory = construct_memory(prev_task, model, prev_batch, memory, args=self.args)
# prev_train_dataset = PseudoDataset(inferred_data, tokz=self.tokz, max_input_len=self.args.max_input_len)
#! Zero-shot evaluation for new tasks
if len(prev_tasks)>=1 and self.args.evaluate_zero_shot:
self.logger.info(f'Evaluate the new task <<{curr_task}>> zero-shot with the previously learned model.')
if self.args.select_adapter and self.args.construct_memory:
test_keys_list = []
curr_test_loader = self.data_loaders[curr_task]['test']
test_iter = iter(curr_test_loader)
for _ in range(3):
curr_test_batch = next(test_iter) # Take three distinct test batches rather than the same first batch.
test_keys, _, _ = get_sentence_embedding(model, curr_test_batch, args=self.args) # keys: list of tensors
test_keys_list += test_keys
curr_center = torch.mean(torch.stack(test_keys_list), dim=0)
old_center_dict = get_old_center_dict(memory, prev_tasks, args=self.args)
initial_adapter_name = adapter_selection(prev_tasks, curr_center, old_center_dict, args=self.args)
eval_adapter_name = initial_adapter_name.replace('.','_')
self.logger.info(f'We select the old adapter {initial_adapter_name} to initialize the current task {curr_task}!')
else:
eval_adapter_name = curr_adapter_name
initial_score = self._eval_score(model, curr_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None, last=True, use_adapter_name=eval_adapter_name)
all_tasks_res.append({curr_task+'_initial':{'score':float(initial_score['score'])}})
if eval_adapter_name != curr_adapter_name:
eval_ema_path = os.path.join(self.save_folder, str(initial_adapter_name)+'_ema_model'+'.pt')
# model = torch.load(eval_ema_path)
model.load_adapter(os.path.join(self.save_folder,eval_adapter_name+'_adapter'), load_as=curr_adapter_name)
model = model.fresh(curr_adapter_name, self.args, ema=False, initial=False)
model.train()
# ema_model = SSLLModel(self.config)
# ema_model = copy.deepcopy(prev_ema_model)
# ema_model.load_state_dict(prev_ema_model.state_dict())
# ema_model = torch.load(eval_ema_path)
ema_model.load_adapter(os.path.join(self.save_folder,eval_adapter_name+'_adapter'), load_as=curr_adapter_name)
ema_model = ema_model.fresh(curr_adapter_name, self.args, ema=True, initial=False)
ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
assert ema_num_tunable_params == 0
self.logger.info('Reinitialized the current task adapter from the selected old adapter.')
self.logger.info("Begin training!"+ "=" * 40)
self.logger.info(f'Currently training {curr_task}'+'-'*10)
evalout_dir = os.path.join(self.args.output_dir, curr_task)
if os.path.exists(evalout_dir) and os.path.isdir(evalout_dir):
shutil.rmtree(evalout_dir)
os.makedirs(evalout_dir, exist_ok=True)
# Optimizer weight decay
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': self.args.weight_decay},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
num_tunable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total number of tunable parameters is %d' % num_tunable_params, flush=True)
optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
# optimizer = Adafactor(params=filter(lambda p: p.requires_grad, model.parameters()), lr=self.args.lr, weight_decay=self.args.weight_decay,
# scale_parameter=False, relative_step=False)
optimizer.zero_grad()
if self.args.backward_augment:
aug_optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
aug_optimizer.zero_grad()
if self.args.noisy_stu or self.args.online_noisy_stu:
stu_optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
stu_optimizer.zero_grad()
# label_train_dataset = self.datasets[curr_task]['val']
label_train_dataset = self.datasets[curr_task]['label_train']
label_train_sampler = torch.utils.data.RandomSampler(label_train_dataset) # not distributed
if self.args.use_unlabel:
unlabel_train_dataset = self.datasets[curr_task]['unlabel_train']
unlabel_train_sampler = torch.utils.data.RandomSampler(unlabel_train_dataset) # not distributed
if self.args.unlabel_amount == 'all':
unlabel_batch_size = self.args.train_batch_size
else:
unlabel_batch_size = int(self.args.unlabel_amount) * self.args.train_batch_size
unlabel_train_loader = DataLoader(unlabel_train_dataset, batch_size=unlabel_batch_size, sampler=unlabel_train_sampler,
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=curr_task))
if self.args.add_pretrain_with_ft or self.args.pretrain_first:
mix_train_file = os.path.join(self.args.data_dir, curr_task, 'pretrain.txt')
if not os.path.exists(mix_train_file):
_ = write_mix_train_file(label_train_dataset, unlabel_train_dataset, mix_train_file, os.path.join(self.args.data_dir, curr_task))
pretrain_dataloader = create_dataloader_for_pretrain(mix_train_file, self.tokz, model, self.args)
# * Get train loader.
label_train_loader = DataLoader(label_train_dataset, batch_size=self.args.train_batch_size, sampler=label_train_sampler,
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=curr_task))
t_total = len(label_train_loader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs[curr_task]
if self.args.warmup_steps > 0:
num_warmup_steps = self.args.warmup_steps
else:
num_warmup_steps = int(t_total * self.args.warmup_proportion)
if not self.args.nouse_scheduler: # if True, not change lr.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, t_total)
if self.args.pretrain_first and self.args.use_unlabel:
self.logger.info('BEGIN Pretraining!'+'*'*10)
for pre_epo in range(self.args.pretrain_epoch):
PRE_ITER = enumerate(pretrain_dataloader)
for k, predata in PRE_ITER:
optimizer.zero_grad()
pretrain_input_ids, pretrain_labels, pretrain_decoder_input_ids = get_model_inputs(predata, tokz=self.tokz, model_type='pretrain')
pretrain_outputs = model(input_ids=pretrain_input_ids, labels=pretrain_labels, decoder_input_ids=pretrain_decoder_input_ids)
pretrain_loss = pretrain_outputs.loss
# print('pretrain loss',pretrain_outputs.loss.item(), flush=True)
pretrain_loss.backward()
for n, p in model.named_parameters():
if curr_task in n:
p.grad.data.zero_()
optimizer.step()
self.logger.info(f'Pretrain epoch: {pre_epo}, ' +' Pretrain loss: %.4f' % (pretrain_loss.item()))
self.logger.info('Finish Pretraining!'+'*'*10)
# TODO: Training =====================================================================
# ! Mix training unlabeled and labeled data.
self.task_step = 0
for epoch in range(1, self.args.num_train_epochs[curr_task]+1):
save_steps = int(t_total // self.args.num_train_epochs[curr_task] * self.args.save_epochs)
if self.args.eval_times_per_task > 0:
eval_steps = int(t_total / self.args.eval_times_per_task)
else:
eval_steps = self.args.eval_steps
if self.rank in [0, -1]:
self.logger.info(
f'num_warmup_steps: {num_warmup_steps}, t_total: {t_total}, save_steps: {save_steps}, eval_steps: {eval_steps}')
self.logger.info("Total epoches: %d" % self.args.num_train_epochs[curr_task])
self.logger.info('*'*50+"Epoch: %d" % epoch+'*'*50)
st = time.time()
LITER = enumerate(label_train_loader)
# * Training labeled data.
for i, data in LITER:
if i==0 and epoch == 1:
question = data['batch_text'][0]['question']
print(f'For <<{curr_task}>>, the question is: "{question}"', flush=True)
question_dict[curr_task] = question
if self.args.freeze_plm:
curr_adapter_name = curr_task.replace('.','_')
model.train_adapter(curr_adapter_name)
if self.args.meantc:
ema_model.set_active_adapters(curr_adapter_name)
ema_num_tunable_params = sum(p.numel() for p in ema_model.parameters() if p.requires_grad)
assert ema_num_tunable_params == 0
else: # tune all params.
curr_adapter_name = curr_task.replace('.','_')
model.set_active_adapters(curr_adapter_name)
model.freeze_model(False)
if self.args.meantc:
ema_model.set_active_adapters(curr_adapter_name)
ema_model.freeze_model(False)
# * Get inputs of the LM model.
cls_total = get_model_inputs(data, self.global_step, self.tokz)
label_cls_loss = compute_solver_loss(model, self.loss_fn, cls_total, args=self.args)
if self.args.add_label_lm_loss:
lm_total = get_model_inputs(data, self.global_step, self.tokz, model_type='lm', task=curr_task)
label_lm_loss = compute_lm_loss(model, self.loss_fn, lm_total, args=self.args)
else: label_lm_loss = 0.0
all_label_loss = label_cls_loss + self.args.lm_lambda * label_lm_loss
all_label_loss.backward()
if self.args.add_pretrain_with_ft and self.args.use_unlabel:
pretrain_data = next(iter(pretrain_dataloader))
pretrain_input_ids, pretrain_labels, pretrain_decoder_input_ids = get_model_inputs(pretrain_data, self.global_step, self.tokz, model_type='pretrain')
pretrain_outputs = model(input_ids=pretrain_input_ids, labels=pretrain_labels, decoder_input_ids=pretrain_decoder_input_ids)
pretrain_loss = 0.1 * pretrain_outputs.loss # Down-weight the auxiliary pretraining objective.
print('pretrain loss', pretrain_outputs.loss.item(), flush=True)
pretrain_loss.backward()
# * Only update backbone model parameters that are shared across tasks.
for n, p in model.named_parameters():
# if 'adapter'+curr_task in n:
if curr_task in n and p.grad is not None:
p.grad.data.zero_()
# elif 'lm_head'+curr_task in n:
# p.grad.data.zero_()
else: pretrain_loss = 0.0
if not self.args.debug_use_unlabel:
if self.args.stu_feedback and self.args.meantc:
fb_loss = 0.1 * compute_feedback_loss(model, ema_model, self.loss_fn, cls_total, args=self.args, kl_loss=self.kl_loss)
# fb_loss = compute_feedback_loss(model, ema_model, self.loss_fn, cls_total, args=self.args, kl_loss=self.kl_loss)
# if fb_loss <=1e-1: use_unlabel=True
if abs(fb_loss-label_cls_loss)<=self.args.feedback_threshold: use_unlabel=True
else: use_unlabel=False
elif epoch >= self.args.warmup_epoch:
use_unlabel=True
else: use_unlabel=False
else:
use_unlabel=True # Just test the memory burst situation.
if self.args.use_unlabel:
if self.args.unlabel_amount=='all':
unlabel_iter = len(unlabel_train_loader) // len(label_train_loader)
else:
unlabel_iter = int(self.args.unlabel_amount) # How many batches of unlabeled data used per label batch.
# for j in range(1):
for j in range(unlabel_iter):
unlabel_data = next(iter(unlabel_train_loader))
if self.args.add_unlabel_lm_loss:
unlabel_lm_total = get_model_inputs(unlabel_data, self.global_step, self.tokz, model_type='lm', task=curr_task)
unlabel_lm_loss = (1/unlabel_iter) * self.args.lm_lambda * compute_lm_loss(model, self.loss_fn, unlabel_lm_total, args=self.args)
unlabel_lm_loss.backward()
else: unlabel_lm_loss = 0.0
#* Use the current task unlabeled data to augment the memory for old tasks.
if self.args.backward_augment:
curr_keys, curr_values, curr_questions = get_sentence_embedding(model, unlabel_data, args=self.args)
curr_unlabel_memory.push(curr_keys, curr_values, curr_questions)
model.train_adapter(curr_adapter_name)
#! Use the unlabeled data for consistency regularization.
if use_unlabel:
if self.args.input_aug:
unlabeled_batch_text = unlabel_data['batch_text']
aug_unlabel_data = aug_unlabeled_data(unlabeled_batch_text, self.tokz, self.args, curr_task)
un_cls_total = get_model_inputs(aug_unlabel_data, self.global_step, self.tokz, data_type='unlabeled')
else:
un_cls_total = get_model_inputs(unlabel_data, self.global_step, self.tokz, data_type='unlabeled')
consistency_loss = (1 / unlabel_iter) * self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, un_cls_total, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=epoch, task_name=curr_task)
consistency_loss.backward()
#! Augment the current unlabeled data with pseudo unlabeled data retrieved from the memory.
if self.args.construct_memory and len(prev_tasks)>=1 and self.args.forward_augment:
assert memory is not None
all_tasks_neighbors = []
for prev_task in prev_tasks:
querys, _, query_questions = get_sentence_embedding(model, unlabel_data, args=self.args) # For current task
# Get neighbors with the corresponding questions.
neighbors = get_neighbors(querys, prev_task, self.args.kneighbors, memory=memory, args=self.args, questions=query_questions)
model.train_adapter(curr_adapter_name)
all_tasks_neighbors += neighbors
if len(all_tasks_neighbors)>0:
print(f'The constructed memory contains "{len(all_tasks_neighbors)}" neighbors to augment.',flush=True)
pseudo_unlabel_batch = create_batch_from_memory(all_tasks_neighbors, self.tokz, self.args, curr_task)
if pseudo_unlabel_batch is not None:
pseudo_dataset = MemoryDataset(pseudo_unlabel_batch)
pseudo_sampler = torch.utils.data.RandomSampler(pseudo_dataset)
pseudo_loader = DataLoader(pseudo_dataset, batch_size=self.args.train_batch_size, sampler=pseudo_sampler, num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))
for _, pseudo_batch in enumerate(pseudo_loader):
# self.logger.info('Finish creating batch with the retrieved pseudo unlabeled data from previously learned tasks.')
pseudo_un_cls_total = get_model_inputs(pseudo_batch, self.global_step, self.tokz, data_type='unlabeled')
consistency_loss_from_memory = (1/unlabel_iter) * self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, pseudo_un_cls_total, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=epoch, task_name=curr_task)
consistency_loss_from_memory.backward()
print(f'forward memory loss: {consistency_loss_from_memory.item()}')
if (i + 1) % self.args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
if not self.args.nouse_scheduler:
scheduler.step() # Can change the lr.
if self.args.meantc:
# update_ema_variables(model, ema_model, self.args.ema_decay, self.global_step)
update_ema_variables(model, ema_model, self.args.ema_decay, self.task_step)
optimizer.zero_grad()
self.task_step += 1
# Update the learning rate.
lr = scheduler.get_last_lr()[0]
# Write in the logger.
self.logger.info(f'EPOCH: {epoch}, task step: {self.task_step}, global step: {self.global_step} ' + 'Training loss: %.4f' % (label_cls_loss) + ' LM loss: %.4f' % (label_lm_loss))
# Log info to Tensorboard
self.train_writer.add_scalar('train_label_cls_loss', label_cls_loss, self.global_step)
self.train_writer.add_scalar('train_label_lm_loss', label_lm_loss, self.global_step)
self.train_writer.add_scalar('time_per_batch', time.time() - st, self.global_step)
self.train_writer.add_scalar('lr', lr, self.global_step)
st = time.time()
self.global_step += 1
# TODO: Evaluate model on Valid datasets. ================================================================================================
if self.global_step % self.args.eval_steps == 0 and self.args.eval_steps > 0:
all_res_per_eval = {}
avg_metric_loss, avg_metric = [], []
tc_avg_metric = []
model.eval()
for val_task in self.args.tasks[:len(prev_tasks)+1]:
eval_loss = self._eval(model, val_task, self.loss_fn)
eval_metric = self._eval_score(model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None)
if self.args.meantc:
tc_eval_metric = self._eval_score(ema_model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task+'_teacher') if self.args.log_eval_res else None)
avg_metric_loss.append(float(eval_loss['loss']))
avg_metric.append(float(eval_metric['score']))
if self.args.meantc:
tc_avg_metric.append(float(tc_eval_metric['score']))
for key, value in eval_loss.items():
self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)
for key, value in eval_metric.items():
self.valid_writer.add_scalar(f'{val_task}/{key}', value, self.global_step)
if self.args.meantc:
for key, value in tc_eval_metric.items():
self.valid_writer.add_scalar(f'{val_task}/tc_{key}', value, self.global_step)
self.logger.info('Evaluation! '+f'Current task: {curr_task}, task_epoch: {epoch}, task step: {self.task_step}, ' + f'global step: {self.global_step}, {val_task}: teacher eval metric: {tc_eval_metric}')
self.logger.info('Evaluation! '+f'Current task: {curr_task}, task_epoch: {epoch}, task step: {self.task_step}, ' + f'global step: {self.global_step}, {val_task}: eval metric: {eval_metric}')
all_res_per_eval[val_task] = {'score': eval_metric['score']}
# print(all_res_per_eval,flush=True)
all_tasks_res.append(all_res_per_eval)
self.valid_writer.add_scalar(f'avg/score', sum(avg_metric)/len(avg_metric), self.global_step)
if self.args.meantc:
self.valid_writer.add_scalar(f'avg/tc_score',sum(tc_avg_metric)/len(tc_avg_metric), self.global_step)
model.train()
# if not self.args.nouse_scheduler:
# scheduler.step() # Can change the lr.
self.logger.info("Training loop. The %dth epoch completed."%epoch)
if self.args.backward_augment and len(prev_tasks)>0:
aug_optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, correct_bias=True)
for g in aug_optimizer.param_groups:
g['lr'] = 5e-6
# g['lr'] = 1e-6
aug_optimizer.zero_grad()
self.logger.info(f'Begin backward augmentation training!')
for old_task in self.args.tasks[:len(prev_tasks)]:
old_adapter_name = old_task.replace('.','_')
model.set_active_adapters(old_adapter_name)
ema_model.set_active_adapters(old_adapter_name)
model.train_adapter(old_adapter_name)
old_memory = memory[old_task]
aug_old_unlabel_batch = create_batch_to_augment_memory(old_task, old_memory, curr_unlabel_memory, tokz=self.tokz, args=self.args)
if aug_old_unlabel_batch is not None:
aug_dataset = MemoryDataset(aug_old_unlabel_batch)
aug_sampler = torch.utils.data.RandomSampler(aug_dataset) # not distributed
aug_loader = DataLoader(aug_dataset, batch_size=self.args.train_batch_size, sampler=aug_sampler,
num_workers=self.args.num_workers, pin_memory=True, collate_fn=PadBatchSeq(self.tokz.pad_token_id, tokz=self.tokz, task=None))
for i, aug_batch in enumerate(aug_loader):
    # Bind the batch index explicitly; it was previously leaking from an
    # enclosing scope while being passed as the ramp-up step below.
    aug_old_untotal = get_model_inputs(aug_batch, self.global_step, self.tokz, data_type='unlabeled')
    backaug_consistency_loss = self.args.ungamma * compute_consistency_loss(model, ema_model, self.loss_fn, aug_old_untotal, args=self.args, kl_loss=self.kl_loss, tokz=self.tokz, epoch=i, task_name=old_task)
    backaug_consistency_loss.backward()
    self.logger.info(f'backward augment loss {backaug_consistency_loss.item()}')
    aug_optimizer.step()
    aug_optimizer.zero_grad()
update_ema_variables(model, ema_model, self.args.ema_decay, 1e5)
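# update_ema_variables (defined with the trainer utilities) is assumed to apply the
# standard mean-teacher EMA step (Tarvainen & Valpola, 2017):
#   teacher_p = alpha * teacher_p + (1 - alpha) * student_p
# where alpha is clipped by self.args.ema_decay; the large fixed step (1e5)
# passed here effectively pins alpha at ema_decay.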
# * Test all learned tasks after training the certain epochs.
if self.args.test_all:
last_avg_metric = []
last_tc_avg_metric = []
curr_task_score = {}
prev_tasks_scores = []
for val_task in self.args.tasks[:len(prev_tasks)+1]:
last_eval_metric = self._eval_score(model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task) if self.args.log_eval_res else None, last=True)
last_tc_eval_metric = self._eval_score(ema_model, val_task, out_dir=os.path.join(self.args.output_dir, curr_task+'_teacher') if self.args.log_eval_res else None, last=True)
last_avg_metric.append(float(last_eval_metric['score']))
last_tc_avg_metric.append(float(last_tc_eval_metric['score']))
if val_task == curr_task:
curr_task_score['score'] = float(last_eval_metric['score'])
curr_task_score['tc_score'] = float(last_tc_eval_metric['score'])
else:
prev_tasks_scores.append({val_task:{'score':float(last_eval_metric['score']),'tc_score':float(last_tc_eval_metric['score'])}})
all_tasks_res.append({curr_task+'_finish':curr_task_score})
# Average scores of previous learned task except the current learned task.
if len(prev_tasks)>=1:
avg_old_stu_score = sum(last_avg_metric[:-1])/len(last_avg_metric[:-1])
avg_old_tc_score = sum(last_tc_avg_metric[:-1])/len(last_tc_avg_metric[:-1])
average_old_score = {'score': avg_old_stu_score,'tc_score':avg_old_tc_score}
all_tasks_res.append({'prev_tasks_scores':prev_tasks_scores})
all_tasks_res.append({'Average_wo_curr_task':average_old_score})
# Average scores of all learned tasks.
avg_stu_score = sum(last_avg_metric)/len(last_avg_metric)
avg_tc_score = sum(last_tc_avg_metric)/len(last_tc_avg_metric)
self.valid_writer.add_scalar(f'avg_alldata/score', avg_stu_score, self.global_step)
self.logger.info(f'Last Evaluation! Average score: {avg_stu_score}; Average teacher score: {avg_tc_score}'+'*'*10)
average_score = {'score': avg_stu_score,'tc_score': avg_tc_score}
all_tasks_res.append({'Average':average_score})
# TODO: Model saving. ------------------------------------------------------------------------
task_final_save_path = os.path.join(self.save_folder, str(curr_task)+'_model'+'.pt')
self.logger.info('Saving model checkpoint of task %s to %s'%(curr_task, task_final_save_path))
torch.save(model.state_dict(), task_final_save_path)
model_save_path = os.path.join(self.save_folder, 'model_last.pt')
torch.save(model.state_dict(), model_save_path)
if self.args.meantc:
ema_task_save_path = os.path.join(self.save_folder, str(curr_task)+'_ema_model'+'.pt')
ema_all_save_path = os.path.join(self.save_folder, 'ema_model_last.pt')
torch.save(ema_model.state_dict(), ema_task_save_path)
torch.save(ema_model.state_dict(), ema_all_save_path)
ema_model.save_adapter(os.path.join(self.save_folder,curr_adapter_name+'_adapter'), curr_adapter_name) # save teacher adapter.
self.logger.info('Saving total model checkpoint to %s', model_save_path)
self.logger.info('='*25+'Training complete!'+'='*25)
return model, ema_model, all_tasks_res, question_dict
def _eval(self, model, task, loss_fn):
lm_recorder = MetricsRecorder(self.device, 'loss')
with torch.no_grad():
adapter_name = task.replace('.','_')
model.set_active_adapters(adapter_name)
model.eval()
if self.args.test_overfit:
test_dataset = self.data_loaders[task]['label_train']
else:
test_dataset = self.data_loaders[task]['test']
for i, data in enumerate(test_dataset):
label_total = get_model_inputs(data, self.global_step, self.tokz)
loss = compute_solver_loss(model, loss_fn, label_total, args=self.args)  # use the loss_fn passed in by the caller
lm_recorder.metric_update(i, {'loss': loss})
if i >= 35:
break
if self.rank != -1:
lm_recorder.all_reduce()
return lm_recorder
# * Calculate accuracy of the model
def _eval_score(self, model, task, out_dir=None, last=False, use_adapter_name=None):
    assert task is not None
    # The per-task maximum answer length is precomputed in max_ans_length_dict;
    # it supersedes the max lengths stored on the datasets themselves.
    max_ans_len = max_ans_length_dict[task]
if out_dir is not None:
if not os.path.exists(out_dir):
os.makedirs(out_dir)
out_dir = os.path.join(out_dir, f'{task}_step{self.global_step}.json')
out_dir_content = []
with torch.no_grad():
if use_adapter_name:
model.set_active_adapters(use_adapter_name)
else:
adapter_name = task.replace('.','_')
model.set_active_adapters(adapter_name)
# print('model of task: {}'.format(task),list(model.parameters()), flush=True)
model.eval()
pred_ans_all, gold_ans_all = [], []
context_all = []
all_pred_task_name = []
if task in ['wikisql','woz.en', 'squad']:
all_gold_list = []
if self.args.test_overfit:
test_dataloader = self.data_loaders[task]['label_train']
else:
test_dataloader = self.data_loaders[task]['test']
for i, data in enumerate(test_dataloader):
example = data['context_id'].cuda()
example_mask = data['context_mask'].cuda()
if out_dir is not None:
context_all.append(example)
gold_ans = data['ans_id'].cuda()
# Print model inputs to check.
# if i==0:
# print('context:',self.tokz.decode(example[0]),flush=True)
# print('ans:',self.tokz.decode(gold_ans[0]),flush=True)
pred_ans, _ = get_answer(self.tokz, model, example, example_mask, max_ans_len, args=self.args, is_eval=True)
# print('context shape',context.shape,flush=True)
# pred_ans = pred_ans[:,context.shape[1]:]
pred_ans = pred_ans[:,1:] # Remove the first <pad> token.
# print('pred',pred_ans,flush=True)
# print('gold',gold_ans,flush=True) # output tensors
pred_ans_all.append(pred_ans)
gold_ans_all.append(gold_ans)
if task in ['wikisql','woz.en', 'squad']:
for k in range(len(data['context_id'])):
gold_ans_list = data['batch_text'][k]['ans_list']
all_gold_list.append(gold_ans_list)
# if i>=1: break
if not last and task not in ['wikisql','woz.en','squad']:
    # Cap intermediate evaluation at 36 batches for the batch sizes used in our runs.
    if self.args.train_batch_size in (64, 16) and i >= 35:
        break
# collect results from different nodes
pred_ans_all = communicate_tensor(pred_ans_all, pad_token=self.tokz.pad_token_id).tolist()
gold_ans_all = communicate_tensor(gold_ans_all, pad_token=self.tokz.pad_token_id).tolist()
if out_dir is not None:
context_all = communicate_tensor(context_all, pad_token=self.tokz.pad_token_id).tolist()
total_count = len(pred_ans_all)
qa_results = []
pred_list = []
print(f'Total number of evaluate data is {total_count}',flush=True)
for i in range(total_count):
    # Build res unconditionally so the loop also works when out_dir is None;
    # it is only appended for writing when an output file was requested.
    res = {}
    if out_dir is not None:
        context = self.tokz.decode(strip_list(context_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
        res['context'] = context
    pred = self.tokz.decode(cut_eos(pred_ans_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
    res['ans_pred'] = pred
    pred_list.append(pred)
    # * Get ground truth.
    if task == 'woz.en':
        gold = self.datasets[task]['test'].answers[i]
        res['ans_gold'] = gold[1]
    elif task == 'wikisql':
        gold = self.datasets[task]['test'].answers[i]
        res['ans_gold'] = gold['answer']
    elif task == 'squad':
        gold = all_gold_list[i]
        assert type(gold) == list
        qa_results.append([pred, gold])
        res['ans_gold'] = gold[0]
    else:
        gold = self.tokz.decode(cut_eos(gold_ans_all[i], self.tokz.eos_token_id), skip_special_tokens=True)
        qa_results.append([pred, [gold]])
        res['ans_gold'] = gold
    if out_dir is not None:
        out_dir_content.append(res)
if task in ['wikisql','woz.en']:
test_dataset = self.datasets[task]['test']
ids = test_dataset.get_indices()
test_dataset.sort_by_index()
qa_results = [x[1] for x in sorted([(i,g) for i, g in zip(ids, pred_list)])]
for i in range(len(test_dataset)):
Y = test_dataset.answers[i]
qa_results[i] = [qa_results[i], Y]
score_dict = compute_metrics(qa_results,
bleu='iwslt.en.de' in task or 'multinli.in.out' in task,
dialogue='woz.en' in task,
rouge='cnn_dailymail' in task,
logical_form='wikisql' in task,
corpus_f1='zre' in task )
self.logger.info(f'Score dictionary for task {task}>> {score_dict}')
metric = task2metric_dict[task]
score = score_dict[metric]
if out_dir is not None:
with open(out_dir, 'w', encoding='utf-8') as outfile:
for res in out_dir_content:
print(json.dumps(res), file=outfile)
model.train()
return {'score': '{:.6f}'.format(score)}
# * Train across tasks through time.
def train(self, tasks):
model = None
ema_model = None
self.logger.info('Total tasks number is '+str(len(tasks)))
all_tasks_res = []
res_file = os.path.join(self.save_folder, 'metrics.json')
question_dict = {}
for i in range(len(tasks)):
model, ema_model, all_tasks_res, question_dict = self._train(model, ema_model, curr_task=tasks[i], prev_tasks=tasks[:i], all_tasks_res=all_tasks_res, question_dict=question_dict)
self.logger.info('We have trained %d tasks'%(i+1))
if self.args.use_memory:
self.logger.info('Memory has saved information of %d tasks'%(len(memory.memory.keys()))+'-'*10)
# * Evaluation metrics saving.
with open(res_file,'w', encoding='utf-8') as f:
for e in all_tasks_res:
print(json.dumps(e,ensure_ascii=False), file=f)

358
ssll/unifymodel/utils.py Normal file
View File

@ -0,0 +1,358 @@
from unifymodel.dataset import PadBatchSeq, pad_seq, get_unlabel_data
import logging
import random
import copy
import os, time, gc, json, pickle, argparse, math, re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data as data
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from transformers import GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2Tokenizer
from tools.eda import *
loggers = {}
def get_logger(filename, level=logging.INFO, print2screen=True):
    global loggers
    # Return a cached logger first; deleting the log file after its handler has
    # been opened would leave that handler writing to an orphaned file.
    if loggers.get(filename):
        return loggers.get(filename)
    if os.path.exists(filename):
        os.remove(filename)
    logger = logging.getLogger(filename)
logger.setLevel(level)
fh = logging.FileHandler(filename, encoding='utf-8')
fh.setLevel(level)
ch = logging.StreamHandler()
ch.setLevel(level)
formatter = logging.Formatter('[%(asctime)s][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
# formatter = logging.Formatter('[%(asctime)s][%(thread)d][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
if print2screen:
logger.addHandler(ch)
loggers[filename] = logger
return logger
def frange_cycle_linear(n_iter, start=0.01, stop=1.0, n_cycle=4, ratio=0.5):
L = np.ones(n_iter) * stop
period = n_iter/n_cycle
step = (stop-start)/(period*ratio) # linear schedule
for c in range(n_cycle):
v, i = start, 0
while v <= stop and (int(i+c*period) < n_iter):
L[int(i+c*period)] = v
v += step
i += 1
return L
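# Worked example (values illustrative): frange_cycle_linear(8, start=0.0, stop=1.0,
# n_cycle=2, ratio=0.5) gives period 4 and step 0.5, so each cycle ramps linearly
# and then plateaus at the cap: [0.0, 0.5, 1.0, 1.0, 0.0, 0.5, 1.0, 1.0].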
def num_params(model):
return sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])
def switch_schedule(schedule, mult, switch):
""" Apply LR multiplier before iteration "switch" """
def f(e):
s = schedule(e)
if e < switch:
return s * mult
return s
return f
def linear_schedule(args):
def f(e):
if e <= args.warmup:
return e / args.warmup
return max((e - args.iterations) / (args.warmup - args.iterations), 0)
return f
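# Sketch of the resulting multiplier (assuming args.warmup=100, args.iterations=1000):
# f(50)=0.5 during warm-up, f(100)=1.0 at its end, then a linear decay with
# f(550)=0.5 and f(1000)=0.0; intended for LambdaLR-style schedulers.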
def get_mask(x_len):
mask = torch.arange(max(x_len), device=x_len.device)[None, :] < x_len[:, None] # [bs, max_len]
return mask.bool()
def get_reverse_mask(x_len):
mask = torch.arange(max(x_len)-1, -1, -1, device=x_len.device)[None, :] < x_len[:, None] # [bs, max_len]
return mask.bool()
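# Example (illustrative): for x_len = tensor([2, 3]) the masks over max_len=3 are
#   get_mask         -> [[True, True, False], [True, True, True]]  (left-aligned)
#   get_reverse_mask -> [[False, True, True], [True, True, True]]  (right-aligned,
# matching the left-padded batches built by padding_convert below).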
def compare_tokens(x, y, eos_id):
if eos_id in x:
x = x[:x.index(eos_id)]
if eos_id in y:
y = y[:y.index(eos_id)]
return x == y
def pad_tensor(tensor, length, pad_token=0):
return torch.cat([tensor, tensor.new(tensor.size(0), length - tensor.size()[1]).fill_(pad_token)], dim=1)
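# Example: pad_tensor(torch.tensor([[1, 2]]), 4, pad_token=0) -> tensor([[1, 2, 0, 0]]);
# right-pads along dim 1 so tensors of different lengths can be concatenated.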
def seed_everything(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(seed)
torch.cuda.manual_seed_all(seed)
def communicate_tensor(tensor_list, pad_token=0):
'''
collect tensors from all processes
'''
if len(tensor_list) == 0:
return None
device = tensor_list[0].device
max_len = torch.tensor(max([i.shape[1] for i in tensor_list]), dtype=torch.int64, device=device)
if dist.is_initialized(): # Obtain the max_len of the second dim of each tensor
dist.all_reduce(max_len, op=dist.ReduceOp.MAX)
# Pad tensors to the max_len
tensor = torch.cat([pad_tensor(i, max_len, pad_token) for i in tensor_list], dim=0)
tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
max_tensor_bs = torch.tensor(tensor.shape[0], dtype=torch.int64, device=device)
if dist.is_initialized():
dist.all_reduce(max_tensor_bs, op=dist.ReduceOp.MAX) # Obtain the max_tensor_bs of each tensor
if max_tensor_bs != tensor_bs:
tensor = torch.cat([tensor, tensor.new(max_tensor_bs-tensor_bs, tensor.shape[1]).fill_(pad_token)], dim=0)
# Gather padded tensors and the bs of each tensor
tensor_list = [torch.ones_like(tensor).fill_(pad_token) for _ in range(dist.get_world_size())]
tensor_bs_list = [torch.ones_like(tensor_bs).fill_(pad_token) for _ in range(dist.get_world_size())]
dist.all_gather(tensor_list=tensor_list, tensor=tensor.contiguous())
dist.all_gather(tensor_list=tensor_bs_list, tensor=tensor_bs)
# Cut the padded batch
for i in range(dist.get_world_size()):
tensor_list[i] = tensor_list[i][:tensor_bs_list[i]]
tensor = torch.cat(tensor_list, dim=0)
return tensor
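# Sketch: in a 2-process run where rank 0's local tensors concatenate to [2, 5] and
# rank 1's to [3, 4], every rank ends up with the same [5, 5] result (padded with
# pad_token); without torch.distributed initialized it reduces to pad-and-concat.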
class MetricsRecorder(object):
def __init__(self, device, *metric_names):
self.metric_names = list(metric_names)
self.device = device
self.metrics = {}
for metric_name in metric_names:
self.metrics[metric_name] = torch.tensor(0, dtype=torch.float64, device=self.device)
def metric_update(self, batch_no, metric_values_dict):
for k, v in metric_values_dict.items():
self.metrics[k] = (self.metrics[k] * batch_no + v) / (batch_no + 1)
def add_to_writer(self, writer, step, group):
for n in self.metric_names:
m = self.metrics[n].item()
writer.add_scalar('%s/%s' % (group, n), m, step)
def write_to_logger(self, logger, epoch, step):
log_str = 'epoch {:>3}, step {}'.format(epoch, step)
for n in self.metric_names:
m = self.metrics[n].item()
log_str += ', %s %g' % (n, m)
logger.info(log_str)
def items(self):
return self.metrics.items()
def all_reduce(self):
for n in self.metric_names:
torch.distributed.all_reduce(self.metrics[n], op=torch.distributed.ReduceOp.SUM)
self.metrics[n] /= torch.distributed.get_world_size()
def __getitem__(self, k):
return self.metrics[k]
def __setitem__(self, key, value):
self.metrics[key] = value
def __repr__(self):
return self.metrics.__repr__()
def __str__(self):
return self.metrics.__str__()
def keys(self):
return self.metrics.keys()
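# Usage sketch (hypothetical values) -- the recorder keeps a running mean per metric:
#   rec = MetricsRecorder(torch.device('cpu'), 'loss')
#   rec.metric_update(0, {'loss': 2.0})   # running mean = 2.0
#   rec.metric_update(1, {'loss': 4.0})   # running mean = (2.0*1 + 4.0)/2 = 3.0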
def cut_eos(seq, eos_id):
if eos_id not in seq:
return seq
return seq[:seq.index(eos_id)]
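# Example: cut_eos([5, 8, 2, 0, 0], eos_id=2) -> [5, 8]; sequences without the
# eos id are returned unchanged.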
def cal_metrics_from_pred_files(res_file):
with open(res_file, 'r', encoding='utf-8') as f:
res = [json.loads(i) for i in f.readlines()]
y_true = [i['ans_gold'] for i in res]
y_pred = [i['ans_pred'] for i in res]
return {
"accuracy": accuracy_score(y_true, y_pred),
}
def slot_f1_score(pred_slots, true_slots):
'''
pred_slots, true_slots are like [['from_location:10-11', 'leaving_date:12-13']]
'''
slot_types = set([slot.split(":")[0] for row in true_slots for slot in row])
slot_type_f1_scores = []
for slot_type in slot_types:
predictions_for_slot = [[p for p in prediction if slot_type in p] for prediction in pred_slots] # [['from_location'],[],[],['from_location']]
labels_for_slot = [[l for l in label if slot_type in l] for label in true_slots]
proposal_made = [len(p) > 0 for p in predictions_for_slot]
has_label = [len(l) > 0 for l in labels_for_slot]
prediction_correct = [prediction == label for prediction, label in zip(predictions_for_slot, labels_for_slot)]
true_positives = sum([
int(proposed and correct)
for proposed, correct in zip(proposal_made, prediction_correct)])
num_predicted = sum([int(proposed) for proposed in proposal_made])
num_to_recall = sum([int(hl) for hl in has_label])
precision = true_positives / (1e-5 + num_predicted)
recall = true_positives / (1e-5 + num_to_recall)
f1_score = 2 * precision * recall / (1e-5 + precision + recall)
slot_type_f1_scores.append(f1_score)
return np.mean(slot_type_f1_scores)
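# Worked example (illustrative): with true_slots = [['from_location:10-11']] and an
# identical prediction, the single slot type gets precision = recall ~= 1 (up to
# the 1e-5 smoothing), so the mean F1 is ~1.0; a missed slot type instead pulls
# its recall, and hence the mean, toward 0.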
def textid_decode(text, eos, tokz, contain_eos=False):
    if eos in text:
        idx = text.index(eos)
        # Keep the eos token itself only when contain_eos is set.
        text = text[:idx+1] if contain_eos else text[:idx]
    text_id = text
    len_text = len(text)
    text = tokz.decode(text).strip()
    return text, len_text, text_id
def padding_convert(text_list, eos):
tt_list = []
for text in text_list:
if eos in text:
eos_indexs = [i for i, x in enumerate(text) if x==eos]
# print('eos index', eos_indexs,flush=True)
if len(eos_indexs)>1: text = text[:eos_indexs[1]]
tt_list.append(text)
tt_lens = [len(i) for i in tt_list]
tt_lens = torch.tensor(tt_lens, dtype=torch.long).to('cuda')
tt_pad = torch.tensor([pad_seq(i, eos, max(tt_lens), pad_left=True) for i in tt_list], dtype=torch.long).to('cuda')
return tt_pad, tt_lens
def strip_list(seq, eos_id):
    # Strip leading and trailing eos_id tokens. The sentinels start just outside
    # the sequence so inputs without eos padding come back unchanged.
    l, r = -1, len(seq)
    for i in range(len(seq)):
        if seq[i] != eos_id:
            break
        l = i
    for i in range(len(seq)-1, -1, -1):
        if seq[i] != eos_id:
            break
        r = i
    return seq[l+1:r]
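# Example: strip_list([2, 2, 7, 8, 2], eos_id=2) -> [7, 8]; thanks to the sentinel
# initialisation above, an unpadded [7, 8] also comes back unchanged.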
def aug_unlabeled_data(batch, tokz, args, task_name):
# * Augment sentence with EDA.
input_sen, context_sen, all_sen, ans_sen = [],[],[],[]
task_prefix = task_name + ':'
raw_sen, ques_sen = [],[]
for i in batch:
text = i['raw_text']
ans_text = i['ans_sen']
question = i['question']
aug_text_list = eda(text, num_aug=args.num_aug)
# aug_text_list = eda(text, num_aug=1)
for aug_text in aug_text_list:
input_text = task_prefix + aug_text + '<QUES>' + question
input_sen.append(input_text)
context_sen.append(input_text+'<ANS>')
# all_sen.append(aug_text+'<ANS>'+ans_text)
all_sen.append(aug_text+'<QUES>') # train LM only on the inputs
ans_sen.append(ans_text)
raw_sen.append(aug_text)
ques_sen.append(question)
input_encoding = tokz(input_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
all_encoding = tokz(all_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
context_encoding = tokz(context_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
ans_encoding = tokz(ans_sen, padding='longest', max_length=args.max_ans_len, truncation=True, return_tensors='pt')
raw_text_encoding = tokz(raw_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
question_encoding = tokz(ques_sen, padding='longest', max_length=args.max_input_len, truncation=True, return_tensors='pt')
res = {}
res["input_id"], res['input_mask'] = input_encoding.input_ids, input_encoding.attention_mask
res["all_id"], res['all_mask'] = all_encoding.input_ids, all_encoding.attention_mask
res["context_id"], res['context_mask'] = context_encoding.input_ids, context_encoding.attention_mask
res["ans_id"], res['ans_mask'] = ans_encoding.input_ids, ans_encoding.attention_mask
res["raw_id"], res['raw_mask'] = raw_text_encoding.input_ids, raw_text_encoding.attention_mask
res["ques_id"], res['ques_mask'] = question_encoding.input_ids, question_encoding.attention_mask
res['batch_text'] = batch
return res
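# eda() comes from tools.eda and is assumed to follow EDA (Wei & Zou, 2019):
# each call returns num_aug surface variants per sentence via synonym replacement,
# random insertion, random swap and random deletion; every variant is re-joined
# with the task prefix and question so downstream consistency training sees
# comparably formatted unlabeled views.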
def get_current_consistency_weight(args, epoch):
# Consistency ramp-up from https://arxiv.org/abs/1610.02242
return args.consistency * sigmoid_rampup(epoch, args.consistency_rampup)
def sigmoid_rampup(current, rampup_length):
"""Exponential rampup from https://arxiv.org/abs/1610.02242"""
if rampup_length == 0:
return 1.0
else:
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current / rampup_length
return float(np.exp(-5.0 * phase * phase))
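# Example ramp (assuming args.consistency=1.0, args.consistency_rampup=5): the
# unsupervised weight follows exp(-5 * (1 - t/5)^2), i.e. ~0.007 at t=0,
# ~0.29 at t=2.5, and exactly 1.0 once t reaches 5, as in Laine & Aila (2017).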
class KDLoss(nn.Module):
def __init__(self, KD_term=0.0, T=1.0):
super(KDLoss, self).__init__()
assert 0 <= KD_term <= 5
assert 0 < T
self.KD_term = KD_term
self.T = T
def forward(self, output_logits, teacher_logits, label_mask=None):
KD_loss = F.kl_div(F.log_softmax(output_logits / self.T, dim=1),
F.softmax(teacher_logits / self.T, dim=1), reduction='none')
# print('label_mask shape:',label_mask.shape,'loss shape:',KD_loss.shape,flush=True)
# KD_loss [batch_size*seq_len, 50528]
KD_loss = torch.sum(KD_loss, dim=1)
if label_mask is not None:
label_mask = label_mask.view(-1)
KD_loss = KD_loss.where(label_mask.cuda(), torch.tensor(0.0).cuda())
kd_loss = KD_loss.sum() / label_mask.sum()
else:
kd_loss = KD_loss
return kd_loss * self.KD_term * self.T * self.T
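# Usage sketch (hypothetical shapes): distilling a teacher into a student over a
# vocabulary of size V, with logits flattened to [batch*seq_len, V]:
#   kd = KDLoss(KD_term=1.0, T=2.0)
#   loss = kd(student_logits.view(-1, V), teacher_logits.view(-1, V), label_mask)
# The T*T factor keeps the loss scale roughly invariant to the temperature.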
def kl_divergence(mean1, logvar1, mean2, logvar2):
# print(mean1.size(),logvar1.size(),mean2.size(),logvar2.size(),flush=True)
exponential = logvar1 - logvar2 - \
torch.pow(mean1 - mean2, 2) / logvar2.exp() - \
torch.exp(logvar1 - logvar2) + 1
result = -0.5 * torch.sum(exponential, tuple(range(1, len(exponential.shape))))
return result.mean()
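# This is the closed-form KL between diagonal Gaussians,
#   KL(N(m1, e^{lv1}) || N(m2, e^{lv2}))
#     = 0.5 * sum(lv2 - lv1 + (e^{lv1} + (m1 - m2)^2) / e^{lv2} - 1),
# summed over the non-batch dimensions and averaged over the batch.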

56
ssll/unitrain.py Normal file
View File

@ -0,0 +1,56 @@
from unifymodel.utils import get_logger, seed_everything
from unifymodel.trainer import Trainer
from unifymodel.dataset import get_datasets
import torch
from settings import parse_args
import os
from transformers import AdamW, get_linear_schedule_with_warmup, Conv1D
from transformers import T5Tokenizer, T5Model
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_args()
logger = get_logger(args.log_file)
if args.local_rank in [0, -1]:
logger.info('Pytorch Version: {}'.format(torch.__version__))
for k, v in vars(args).items():
logger.info("{}= {}".format(k, v))
# seed everything
seed_everything(args.seed)
cache_dir = os.path.join(args.output_dir, 'model_cache')
os.makedirs(cache_dir, exist_ok=True)
# * Add special tokens.
special_tokens_dict = {'additional_special_tokens': ['<ANS>','<QUES>']}
out = os.path.join(cache_dir,'out')
tokz = T5Tokenizer.from_pretrained('t5-base',cache_dir=out)
num_spe_token = tokz.add_special_tokens(special_tokens_dict)
if args.use_task_pmt:
pre_token_list = []
for task in args.tasks:
pre_token_list += [str(task)+':']
else:
pre_token_list = ['TASK:' for i in range(len(args.tasks))]
num_add_tokens = tokz.add_tokens(pre_token_list)
print('We have added', num_spe_token+num_add_tokens, 'tokens to T5', flush=True)
tokz.save_pretrained(out)
if args.local_rank in [0, -1]:
logger.info('Loading datasets...'+'.'*10)
datasets = get_datasets(args.data_dir, args.tasks, tokz, max_input_len=args.max_input_len)
logger.info('Finish loading datasets!')
if args.use_memory:
    # NOTE: Memory is defined elsewhere in the repo and still needs to be
    # imported here; its module is not part of this file's import list.
    memory = Memory()
else:
memory = None
trainer = Trainer(args, tokz, datasets, logger, cache_dir, memory=memory)
trainer.train(args.tasks)
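# Illustrative launch only (flag names assumed from settings.parse_args):
#   python unitrain.py --tasks woz.en wikisql squad --output_dir out/ssll --seed 42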