#!/usr/bin/env python
# coding=utf-8

import logging
import os

# Surface NCCL errors instead of hanging in distributed runs.
os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'

import torch
import copy
import random
import sys
import json
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from sklearn.cluster import KMeans

# from data import QAData
sys.path.append("../")
from models.metat5notask import T5ForConditionalGeneration as PromptT5
from downstream.dataset_processors import *
from downstream.l2ptrainer import QuestionAnsweringTrainer

import datasets
import numpy as np
from datasets import load_dataset, load_metric, load_from_disk, concatenate_datasets

from functools import partial

os.environ["WANDB_DISABLED"] = "true"

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    M2M100Tokenizer,
    MBart50Tokenizer,
    EvalPrediction,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers that require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]

# QA formats, the datasets belonging to each format, and per-task ids.
format2id = {'extractive': 0, 'abstractive': 1, 'multichoice': 2}
format2dataset = {
    'extractive': ['squad', 'extractive', 'newsqa', 'quoref'],
    'abstractive': ['narrativeqa', 'abstractive', 'nqopen', 'drop'],
    'multichoice': ['race', 'multichoice', 'openbookqa', 'mctest', 'social_iqa', 'dream'],
}
seed_datasets = ['race', 'narrativeqa', 'squad']
dataset2format = {}

for k, vs in format2dataset.items():
    for v in vs:
        dataset2format[v] = k

task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa': 3, 'quoref': 4, 'drop': 5,
           'nqopen': 6, 'openbookqa': 7, 'mctest': 8, 'social_iqa': 9, 'dream': 10}
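# Note on the assumed prompt layout (inferred from the reload logic in main() below, not stated
# explicitly in this file): the encoder's prompt embedding table appears to hold
# len(format2id) * prompt_number format-prompt rows first, followed by prompt_number rows per
# task, i.e. the prompts for task t start at len(format2id) * prompt_number + task2id[t] * prompt_number.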
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    fix_t5: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze the main parameters of T5."}
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
    target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (a jsonlines file) to evaluate the metrics on."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file (a jsonlines file) to evaluate the metrics on."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    after_train: Optional[str] = field(default=None)
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
            "which is used during ``evaluate`` and ``predict``."
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
            "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
            "needs to be the target language token."
        },
    )
    do_lowercase: bool = field(
        default=False,
        metadata={
            'help': 'Whether to lowercase the input text.'
        }
    )
    append_another_bos: bool = field(
        default=False,
        metadata={
            'help': 'Whether to prepend an additional BOS token to the input.'
        }
    )
    context_column: Optional[str] = field(
        default="context",
        metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
    )
    question_column: Optional[str] = field(
        default="question",
        metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
    )
    answer_column: Optional[str] = field(
        default="answers",
        metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
    )
    max_task_num: Optional[int] = field(
        default=30,
        metadata={"help": "The maximum number of tasks supported by the prompt pool."},
    )
    qa_task_type_num: Optional[int] = field(
        default=4,
        metadata={"help": "The number of distinct QA task formats."},
    )
    prompt_number: Optional[int] = field(
        default=100,
        metadata={"help": "The number of prompt tokens allocated to each format/task."},
    )
    add_task_prompt: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to prepend task-specific prompt tokens to the input."},
    )
    reload_from_trained_prompt: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to reload the prompt embedding from a previously trained prompt."},
    )
    trained_prompt_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'The path storing the trained prompt embedding.'
        }
    )
    load_from_format_task_id: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to initialize a new task's prompt from the prompt of its format-corresponding task."}
    )
    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # elif self.source_lang is None or self.target_lang is None:
        #     raise ValueError("Need to specify the source language and the target language.")

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # assert extension == "json", "`train_file` should be a json file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            # assert extension == "json", "`validation_file` should be a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
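# main() runs a sequential (continual) learning loop: for each dataset in task_sequence it
# trains on the new data mixed with a replay buffer, selects a small set of diverse replay
# samples via the encoder's domain keys, retrains on the accumulated replay set, and then
# evaluates on all previously seen tasks.
#
# Example invocation (illustrative only; actual script name, paths, model and hyper-parameters may differ):
#   python this_script.py --model_name_or_path t5-base --output_dir ./outputs \
#       --do_train --predict_with_generate --per_device_train_batch_size 8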
# +
def main():
    def preprocess_validation_function(examples):
        preprocess_fn = dataset_name_to_func(data_args.dataset_name)
        inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)

        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
        # Keep track of which example each feature comes from (needed for post-processing).
        model_inputs["example_id"] = []

        for i in range(len(model_inputs["input_ids"])):
            # One example can give several spans; this is the index of the example containing this span of text.
            sample_index = i  # sample_mapping[i]
            model_inputs["example_id"].append(examples["id"][sample_index])

        # Setup the tokenizer for targets.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        # Prepending format/task prompt ids to the inputs is disabled here; the original logic is kept for reference.
        # format_id = format2id[dataset2format[data_args.dataset_name]]
        # task_id = task2id[data_args.dataset_name]
        # format_prompt_ids = [-(i + 1) for i in range(format_id * data_args.prompt_number,
        #                                              (format_id + 1) * data_args.prompt_number)]
        # task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
        # task_prompt_ids = [-(i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
        #                                            task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
        # model_inputs['input_ids'] = [format_prompt_ids + task_prompt_ids + input_ids
        #                              for input_ids in model_inputs['input_ids']]
        # model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask
        #                                   for attention_mask in model_inputs['attention_mask']]

        return model_inputs
    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    def save_prompt_embedding(model, path):
        # Save the encoder's prompt embedding table together with the task/format id mappings
        # so that it can be restored later.
        prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
        save_prompt_info = {'encoder.prompt_embeddings.weight': copy.deepcopy(prompt_embedding),
                            'task2id': task2id, 'format2id': format2id}
        prompt_path = os.path.join(path, 'prompt_embedding_info')
        torch.save(save_prompt_info, prompt_path)
        logger.info(f'Saving prompt embedding information to {prompt_path}')
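    # A file written by save_prompt_embedding above is expected to be loadable with torch.load and
    # consumed by the --reload_from_trained_prompt branch further below, e.g. (sketch, assuming the
    # same task2id/format2id mappings were used at save time):
    #   info = torch.load('<output_dir>/prompt_embedding_info')
    #   weights = info['encoder.prompt_embeddings.weight']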
    def preprocess_function(examples):
        preprocess_fn = dataset_name_to_func(data_args.dataset_name)
        inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)

        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
        # Setup the tokenizer for targets.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]

        return model_inputs
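    # save_load_diverse_sample picks a small, diverse replay set from the current training set:
    # each example's query vector is matched against the encoder's domain keys, at most three
    # examples are kept per key (falling back to random examples from the unsearched part of the
    # set when a key stays empty), and the result is subsampled to 50 examples whose query
    # features are also returned.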
    def save_load_diverse_sample(model, trainset):
        with torch.no_grad():
            device = 'cuda:0'
            # Only search the first quarter of the training set for key-matched samples.
            search_upbound = len(trainset) // 4
            query_idxs = [None] * 30
            keys = model.encoder.domain_keys
            for idx, item in enumerate(trainset.select(range(search_upbound))):
                query = model.encoder.get_query_vector(
                    input_ids=torch.tensor([item['input_ids']]).long().to(device),
                    attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
                    return_dict=True)
                result = torch.matmul(query, keys.t())
                result = torch.topk(result, 5).indices[0].cpu().numpy().tolist()
                # Assign the example to the first of its top-5 keys that still has fewer than 3 samples.
                key_sel = None
                for key_idx in result:
                    if query_idxs[key_idx] is None or len(query_idxs[key_idx]) < 3:
                        key_sel = key_idx
                        break
                if key_sel is not None:
                    if query_idxs[key_sel] is None:
                        query_idxs[key_sel] = [idx]
                    else:
                        query_idxs[key_sel].append(idx)
            total_idxs = []
            for item in query_idxs:
                if item is not None:
                    total_idxs.extend(item[:3])
                else:
                    # Pad empty keys with random examples from the unsearched part of the training set.
                    total_idxs.extend(random.sample(list(range(search_upbound, len(trainset))), 3))
            total_idxs = list(set(total_idxs))
            total_idxs = random.sample(total_idxs, 50)
            sub_set = trainset.select(total_idxs)

            features = []
            for idx, item in enumerate(sub_set):
                query = model.encoder.get_query_vector(
                    input_ids=torch.tensor([item['input_ids']]).long().to(device),
                    attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
                    return_dict=True)
                features.append(query.detach().cpu().numpy())

        return sub_set, features
    def dataset_name_to_func(dataset_name):
        mapping = {
            'squad': preprocess_sqaud_batch,
            'squad_v2': preprocess_sqaud_batch,
            'boolq': preprocess_boolq_batch,
            'narrativeqa': preprocess_narrativeqa_batch,
            'race': preprocess_race_batch,
            'newsqa': preprocess_newsqa_batch,
            'quoref': preprocess_sqaud_batch,
            'ropes': preprocess_ropes_batch,
            'drop': preprocess_drop_batch,
            'nqopen': preprocess_sqaud_abstractive_batch,
            # 'multirc': preprocess_boolq_batch,
            'boolq_np': preprocess_boolq_batch,
            'openbookqa': preprocess_openbookqa_batch,
            'mctest': preprocess_race_batch,
            'social_iqa': preprocess_social_iqa_batch,
            'dream': preprocess_dream_batch,
        }
        return mapping[dataset_name]
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Local metric scripts used for each dataset.
    dataset_name_to_metric = {
        'squad': 'metric/squad_v1_local/squad_v1_local.py',
        'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
        'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
        'boolq': 'metric/accuracy.py',
        'narrativeqa': 'metric/rouge_local/rouge_metric.py',
        'race': 'metric/accuracy.py',
        'quoref': 'metric/squad_v1_local/squad_v1_local.py',
        'ropes': 'metric/squad_v1_local/squad_v1_local.py',
        'drop': 'metric/squad_v1_local/squad_v1_local.py',
        'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
        'boolq_np': 'metric/accuracy.py',
        'openbookqa': 'metric/accuracy.py',
        'mctest': 'metric/accuracy.py',
        'social_iqa': 'metric/accuracy.py',
        'dream': 'metric/accuracy.py',
    }
    # Setup logging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")
    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
    ]:
        logger.warning(
            "You're running a T5 model but didn't provide a source prefix, which T5 usually expects, e.g. with "
            "`--source_prefix 'translate English to German: '`"
        )
    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Add format markers as regular tokens and structural markers as special tokens.
    # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]', '[QUESTION]', '[CONTEXT]', '[BOOL]', '[EXTRACTIVE]',
    #                       '[MultiChoice]', '[OPTIONS]'])
    tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
    special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', '[OPTIONS]']}

    tokenizer.add_tokens(tokens_to_add)
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    added_tokens = tokenizer.get_added_vocab()
    logger.info('Added tokens: {}'.format(added_tokens))
    model = PromptT5.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        # task_num=data_args.max_task_num,
        # prompt_num=data_args.prompt_number,
        # format_num=data_args.qa_task_type_num,
        # add_task_prompt=False,
    )

    # Make room in the embedding matrix for the tokens added above.
    model.resize_token_embeddings(len(tokenizer))
    # Reload format-specific task prompts for a newly added task.
    # NOTE: both flags are hard-coded to False here, overriding whatever was passed on the command line,
    # so the two branches below are effectively disabled in this script.
    data_args.reload_from_trained_prompt = False  # @
    data_args.load_from_format_task_id = False  # @
    ### before pretrain come !!!!!!

    if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
        task_start_id = data_args.prompt_number * len(format2dataset.keys())
        task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
        format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
        model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id + data_args.prompt_number, :] = \
            model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id + data_args.prompt_number, :]
        logger.info(f'Successfully initialized format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
    elif data_args.reload_from_trained_prompt:
        assert data_args.trained_prompt_path, 'Must specify the path of the stored prompt'
        prompt_info = torch.load(data_args.trained_prompt_path)
        assert prompt_info['task2id'][data_args.dataset_name] == task2id[data_args.dataset_name], \
            f'the task id in the trained prompt does not match the current task id for {data_args.dataset_name}'
        assert prompt_info['format2id'].keys() == format2id.keys(), 'the format mappings do not match'
        task_start_id = data_args.prompt_number * len(format2dataset.keys())

        task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
        logger.info('task id range {} {}'.format(task_id, task_id + data_args.prompt_number))
        model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id + data_args.prompt_number, :] = \
            prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id + data_args.prompt_number, :]
        format_id = format2id[dataset2format[data_args.dataset_name]]
        model.state_dict()['encoder.prompt_embeddings.weight'][format_id * data_args.prompt_number:(format_id + 1) * data_args.prompt_number, :] = \
            prompt_info['encoder.prompt_embeddings.weight'][format_id * data_args.prompt_number:(format_id + 1) * data_args.prompt_number, :]
        logger.info(
            f'Successfully restored task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
    # Set decoder_start_token_id.
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
    if training_args.local_rank == -1 or training_args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to the loss being calculated twice and will take up more memory."
        )
    question_column = data_args.question_column
    context_column = data_args.context_column
    answer_column = data_args.answer_column
    # import random
    if data_args.max_source_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )
    # Load the pre-tokenized evaluation splits for every task from disk (./oursnotask/).
    train_dataloaders = {}
    eval_dataloaders = {}
    replay_dataloaders = {}
    all_replay = None

    for ds_name in ['squad', 'narrativeqa', 'race', 'newsqa', 'quoref', 'drop', 'nqopen', 'openbookqa', 'mctest', 'social_iqa', 'dream']:
        eval_dataloaders[ds_name] = (load_from_disk("./oursnotask/{}-eval.hf".format(ds_name)),
                                     load_from_disk("./oursnotask/{}-evalex.hf".format(ds_name)))
    pre_tasks = []
    pre_general = []
    pre_test = []
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams

    # Order in which tasks are learned; the held-out tasks (drop, quoref, dream) are only evaluated at the end.
    task_sequence = ['squad', 'newsqa', 'narrativeqa', 'nqopen', 'race', 'openbookqa', 'mctest', 'social_iqa']
    # task_sequence = ["woz.en", "srl", "sst", "wikisql", "squad"]
    fileout = open("diana_log.txt", 'w')
    all_replay = None
    all_features = []
    all_ids = []
    cluster_num = 0
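    # Continual-learning loop. For each task in task_sequence:
    #   1) train on the new task's data concatenated with the accumulated replay buffer;
    #   2) on the main process, pick ~50 diverse samples from the new task as replay data and
    #      cache their query features;
    #   3) retrain on the full replay buffer and register the cached ids/features via
    #      model.encoder.add_negs (presumably as negatives for the key/query matching);
    #   4) evaluate on the current task and all previously seen tasks.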
    for cur_dataset in task_sequence:
        p = 200
        if p > len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))):
            p = len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset)))

        trainds = load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))
        cluster_num += 5
        pre_tasks.append(cur_dataset)
        if cur_dataset == task_sequence[-1]:
            # After the last task has been learned, also evaluate on the held-out tasks.
            pre_tasks.extend(["drop", "quoref", "dream"])
        data_args.dataset_name = cur_dataset
        logger.info("current_dataset:" + cur_dataset)

        training_args.do_train = True
        training_args.to_eval = False
        metric = load_metric(dataset_name_to_metric[cur_dataset])
        # Mix the new task's training data with the replay buffer collected so far.
        if all_replay is not None:
            fused = datasets.concatenate_datasets([all_replay, trainds])
        else:
            fused = trainds
        training_args.num_train_epochs = 5
        model.encoder.reset_train_count()
        trainer = QuestionAnsweringTrainer(
            model=model,
            args=training_args,
            train_dataset=fused,
            eval_dataset=None,
            eval_examples=None,
            answer_column_name=answer_column,
            dataset_name=data_args.dataset_name,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        )

        train_result = trainer.train()
        if training_args.local_rank <= 0:
            # Main process only: select a diverse replay subset from the new task and append it
            # to the accumulated replay buffer, then persist both the buffer and the features.
            save_set, features = save_load_diverse_sample(model, trainds)

            if all_replay is None:
                all_replay = save_set
            else:
                all_replay = datasets.concatenate_datasets([all_replay, save_set])

            if all_features == []:
                all_features = features
            else:
                all_features.extend(features)
            np.save("./all_features.npy", np.array(all_features))

            all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
        if training_args.local_rank != -1:
            # In distributed runs, wait for the main process to finish writing the replay data.
            torch.distributed.barrier()
        all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
        all_ids.extend([task2id[cur_dataset]] * 50)
        all_features = np.load("./all_features.npy").tolist()
        # Second pass: train on the accumulated replay buffer only, then register the
        # replay ids/features with the encoder.
        model.encoder.reset_train_count()
        trainer = QuestionAnsweringTrainer(
            model=model,
            args=training_args,
            train_dataset=all_replay,
            eval_dataset=None,
            eval_examples=None,
            answer_column_name=answer_column,
            dataset_name=data_args.dataset_name,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        )

        train_result = trainer.train()

        model.encoder.add_negs(all_ids, all_features)
        # Evaluate on the current task and every previously seen task.
        for pre_dataset in pre_tasks:

            data_args.dataset_name = pre_dataset
            metric = load_metric(dataset_name_to_metric[pre_dataset])
            eval_dataset, eval_examples = eval_dataloaders[pre_dataset]
            trainer = QuestionAnsweringTrainer(
                model=model,
                args=training_args,
                train_dataset=None,
                eval_dataset=eval_dataset,
                eval_examples=eval_examples,
                answer_column_name=answer_column,
                dataset_name=data_args.dataset_name,
                tokenizer=tokenizer,
                data_collator=data_collator,
                compute_metrics=compute_metrics if training_args.predict_with_generate else None,
            )

            torch.cuda.empty_cache()
            logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))

            # Use default generation settings during evaluation (overrides the values computed above).
            max_length, num_beams, ignore_keys_for_eval = None, None, None
            metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval, metric_key_prefix="eval")
            max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)
            if training_args.local_rank <= 0:
                # Append per-task results to the running log file; ignore write errors.
                try:
                    print("after_train_", cur_dataset, "_test_", pre_dataset, file=fileout)
                    print(metrics, file=fileout)
                except Exception:
                    pass
    # The block below appears to be left over from the Hugging Face translation example this
    # script was adapted from: `kwargs` is never defined here, so it is kept only as a comment.
    # languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
    # if len(languages) > 0:
    #     kwargs["language"] = languages

    return None


# -
def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()