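"""Hint-augmented evaluation of a prompt-tuned T5 ("PromptT5") on unseen QA tasks.

A bi-encoder retriever and a cross-encoder reranker select hint texts for each
example (see filthints / generate_samples_unseen); the selected hints are
appended to each input behind a [HINT] separator, and the resulting datasets
are scored with QuestionAnsweringTrainer, one metric script per dataset.
"""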
import os

# Surface NCCL errors instead of hanging, and disable Weights & Biases logging.
os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'
os.environ["WANDB_DISABLED"] = "true"

import copy
import json
import random
import sys
import time
import warnings
from contextlib import suppress as nullcontext
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple

import datasets
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

warnings.filterwarnings("ignore")

import transformers
from datasets import load_dataset, load_metric, load_from_disk, concatenate_datasets
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EvalPrediction,
    HfArgumentParser,
    M2M100Tokenizer,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Project-local modules.
# from data import QAData
sys.path.append("../")
from models.metascorenometa import T5ForConditionalGeneration as PromptT5
from dataset_processors import *
from metatrainernometa import QuestionAnsweringTrainer
from retrievermodel import init_biencoder_components
from rerankermodel.encoder import BertEncoder_For_CrossEncoder
from loadscorea import *
from hintpreprocess import *

def get_model_obj(model: nn.Module):
    # Unwrap a DistributedDataParallel container, if present.
    return model.module if hasattr(model, "module") else model


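# Tokenize evaluation examples and splice the retrieved/reranked hint text onto
# each input: `input tokens [HINT] hint tokens </s>`. Unlike the training-time
# preprocess_function below, no format-prompt ids are prepended here.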
def preprocess_function_eval(examples, format_name):
    preprocess_fn = preprocess_proqa_eval
    # Use whichever hint column this dataset carries: retriever hints,
    # reranker hints, or first-round hints.
    if "hintret" in examples.keys():
        k_ = "hintret"
    elif "hintrer" in examples.keys():
        k_ = "hintrer"
    else:
        k_ = "hintfirst"
    inputs, targets, hints = preprocess_fn(examples, "input", "output", k_, format_name=format_name)

    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    hints_inputs = tokenizer(hints, max_length=max_source_length, padding=padding, truncation=True)
    # Drop each sequence's trailing </s>, join input and hint with the [HINT]
    # separator (id 32108 after the tokens added in main()), then re-append the
    # T5 eos id 1. The attention mask is spliced the same way.
    model_inputs["input_ids"] = [xx[:-1] + [32108] + yy[:-1] + [1]
                                 for xx, yy in zip(model_inputs["input_ids"], hints_inputs["input_ids"])]
    model_inputs["attention_mask"] = [xx[:-1] + [1] + yy[:-1] + [xx[-1]]
                                      for xx, yy in zip(model_inputs["attention_mask"], hints_inputs["attention_mask"])]
    # sample_ids = list(range(len(inputs)))
    # sample_ids = [5000 + _ + dataset2id[tsname] * 1000000 for _ in sample_ids]

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)

    # When padding to max length, replace pad-token ids in the labels with -100
    # so they are ignored by the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]

    # Computed for parity with preprocess_function(); unused at eval time.
    format_id = format2id[dataset2format[tsname]]
    meta_ids = [-(i + 1) for i in range(10)] * 5
    task_prompt_id_start = len(format2id.keys()) * 40
    task_id = task2id[tsname]

    # No format-prompt ids are prepended at evaluation time.
    model_inputs['input_ids'] = copy.deepcopy(model_inputs['input_ids'])
    # model_inputs["sample_id"] = sample_ids

    return model_inputs


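# Training-time preprocessing: the same [HINT] splicing as above, but 40
# format-specific prompt ids are additionally prepended. The negative ids are
# placeholders that PromptT5 presumably resolves to learned prompt embeddings
# (id -(i + 1) selecting prompt slot i).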
def preprocess_function(examples):
    preprocess_fn = preprocess_proqa
    inputs, targets, hints = preprocess_fn(examples, "input", "output", "hint", format_name=format_name)

    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    hints_inputs = tokenizer(hints, max_length=max_source_length, padding=padding, truncation=True)
    model_inputs["input_ids"] = [xx[:-1] + [32108] + yy[:-1] + [1]
                                 for xx, yy in zip(model_inputs["input_ids"], hints_inputs["input_ids"])]
    model_inputs["attention_mask"] = [xx[:-1] + [1] + yy[:-1] + [xx[-1]]
                                      for xx, yy in zip(model_inputs["attention_mask"], hints_inputs["attention_mask"])]
    # sample_ids = list(range(len(inputs)))
    # sample_ids = [_ + dataset2id[tsname] * 1000000 for _ in sample_ids]

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)

    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]

    # Each format owns 40 prompt slots: format_id * 40 .. (format_id + 1) * 40 - 1.
    format_id = format2id[dataset2format[tsname]]
    format_prompt_ids = [-(i + 1) for i in range(format_id * 40, (format_id + 1) * 40)]
    task_prompt_id_start = len(format2id.keys()) * 40
    task_id = task2id[tsname]

    model_inputs['input_ids'] = copy.deepcopy(
        [format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
    # Extend the attention mask to cover the 40 prepended prompt ids.
    model_inputs['attention_mask'] = [[1] * 40 + attention_mask for attention_mask in
                                      model_inputs['attention_mask']]
    # model_inputs["sample_id"] = sample_ids

    return model_inputs


# A list of all multilingual tokenizers that require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]

format2id = {'extractive': 0, 'abstractive': 1, 'bool': 3, 'multichoice': 2}
format2dataset = {
    'extractive': ['squad1_1', 'squad2', 'extractive', 'newsqa', 'quoref', 'ropes', 'adversarialqa_dbert_dev', 'adversarialqa_dbidaf_dev', 'adversarialqa_droberta_dev', 'record_extractive'],
    'abstractive': ['narrativeqa_dev', 'abstractive', 'natural_questions_with_dpr_para', 'drop', 'qaconv', 'tweetqa'],
    'multichoice': ['race_string', 'multichoice', 'openbookqa', 'mctest_corrected_the_separator', 'social_iqa', 'commonsenseqa', 'qasc', 'physical_iqa', 'winogrande_xl', 'onestopqa_advanced', 'onestopqa_elementry', 'onestopqa_intermediate', 'prost_multiple_choice_with_no_context', 'dream', 'processbank_test', 'cosmosqa', 'mcscript', 'mcscript2', 'quail', 'reclor', 'measuring_massive_multitask_language_understanding', 'head_qa_en_test', 'race_c', 'arc_hard', 'arc_easy'],
    'bool': ['boolq', 'bool', 'boolq_np', 'multirc', 'strategyqa', 'pubmedqa_pqal_short_ans']
}
dataset2format = {}
for k, vs in format2dataset.items():
    for v in vs:
        dataset2format[v] = k

task2id = {'squad1_1': 0, 'extractive': 1, 'narrativeqa_dev': 2, 'abstractive': 3, 'race_string': 4, 'multichoice': 5, 'boolq': 6, 'bool': 7, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'natural_questions_with_dpr_para': 12, 'boolq_np': 13, 'openbookqa': 14, 'mctest_corrected_the_separator': 15, 'social_iqa': 16, 'dream': 17}
# Assign fresh ids to every dataset not covered by the fixed mapping above.
num_now = 18
for k in format2dataset.keys():
    for item in format2dataset[k]:
        if item not in task2id.keys():
            task2id[item] = num_now
            num_now += 1
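# Prompt-slot layout implied above: each of the 4 formats reserves 40 prompt
# slots (format_id * 40 onward), and task-specific slots start at
# len(format2id) * 40 = 160. For example, 'multichoice' (format_id 2) maps to
# slots 80..119, i.e. placeholder ids -81 .. -120.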
class StopCallback(TrainerCallback):
    # Stop Trainer-driven training after a single epoch.
    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        control.should_training_stop = True


callbacker = StopCallback()

max_target_length = 128
padding = False

label_pad_token_id = -100

split_gen = False

# Per-dataset example counts for the seen training tasks (not used in this script).
selected = {"squad1_1": 8164, "squad2": 130319, "narrativeqa_dev": 3567, "mctest_corrected_the_separator": 342, "race_string": 14536, "arc_hard": 317, "arc_easy": 395, "boolq": 765, "openbookqa": 580, "newsqa": 445, "quoref": 1574, "ropes": 1272, "drop": 5214, "natural_questions_with_dpr_para": 32590, "commonsenseqa": 1034, "qasc": 653, "physical_iqa": 494, "social_iqa": 2077, "winogrande_xl": 2634, "multirc": 290, "boolq_np": 923}

# Unseen evaluation datasets; their ids key the per-sample hint bookkeeping.
dataset_files = "onestopqa_advanced,onestopqa_elementry,onestopqa_intermediate,prost_multiple_choice_with_no_context,dream,processbank_test,cosmosqa,mcscript,mcscript2,quail,reclor,measuring_massive_multitask_language_understanding,head_qa_en_test,race_c,pubmedqa_pqal_short_ans,strategyqa,tweetqa,qaconv,record_extractive,adversarialqa_dbert_dev,adversarialqa_dbidaf_dev,adversarialqa_droberta_dev".split(",")
dataset2id = {}
for d_ix, item in enumerate(dataset_files):
    dataset2id[item] = d_ix


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    fix_t5: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to freeze the main parameters of T5."}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
    target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on, "
            "a jsonlines file."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (sacrebleu) on, a jsonlines file."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    after_train: Optional[str] = field(default=None)
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
            "which is used during ``evaluate`` and ``predict``."
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
            "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
            "needs to be the target language token."
        },
    )
    do_lowercase: bool = field(
        default=False,
        metadata={
            'help': 'Whether to lowercase the input text.'
        }
    )
    append_another_bos: bool = field(
        default=False,
        metadata={
            'help': 'Whether to add another bos token.'
        }
    )
    context_column: Optional[str] = field(
        default="context",
        metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
    )
    question_column: Optional[str] = field(
        default="question",
        metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
    )
    answer_column: Optional[str] = field(
        default="answers",
        metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
    )
    max_task_num: Optional[int] = field(
        default=30,
        metadata={"help": "The maximum number of tasks."},
    )
    qa_task_type_num: Optional[int] = field(
        default=4,
        metadata={"help": "The number of QA answer formats (extractive/abstractive/multichoice/bool)."},
    )
    prompt_number: Optional[int] = field(
        default=100,
        metadata={"help": "The number of prompt tokens."},
    )
    add_task_prompt: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to add task-specific prompt tokens."},
    )
    reload_from_trained_prompt: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to reload the prompt from a trained prompt."},
    )
    trained_prompt_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'The path storing the trained prompt embedding.'
        }
    )
    load_from_format_task_id: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to reload the prompt from the format-corresponding task prompt."}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # elif self.source_lang is None or self.target_lang is None:
        #     raise ValueError("Need to specify the source language and the target language.")

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # assert extension == "json", "`train_file` should be a json file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            # assert extension == "json", "`validation_file` should be a json file."
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length


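# Score each hint source ("first"-round, retriever "ret", reranker "rer") for
# one answer format by running a single batch per dataset through the model's
# scoring heads and accumulating the means into `priority_level`.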
def evaluate_model(eval_model, priority_level, format_name, training_args, data_collator, epoch_id):
    device_id = training_args.local_rank
    if device_id == -1:
        device_id = 0
    eval_model = eval_model.to(torch.device("cuda", device_id))
    if training_args.local_rank != -1:
        eval_model = nn.parallel.DistributedDataParallel(
            eval_model,
            device_ids=[training_args.local_rank] if training_args.n_gpu != 0 else None,
            output_device=training_args.local_rank if training_args.n_gpu else None,
        )

    for item in dataset_files:
        if item not in format2dataset[format_name]:
            continue
        dataset_ret = load_from_disk("./epoch_datas{}/{}-reteval.hf".format(str(epoch_id), item))
        dataset_rer = load_from_disk("./epoch_datas{}/{}-rereval.hf".format(str(epoch_id), item))
        dataset_first = load_from_disk("./epoch_datas{}/{}-firsteval.hf".format(str(epoch_id), item))
        for (eval_ds, source) in [(dataset_first, "first"), (dataset_ret, "ret"), (dataset_rer, "rer")]:
            if training_args.local_rank == -1:
                data_loader_eval = DataLoader(dataset=eval_ds, shuffle=False,
                                              batch_size=training_args.per_device_train_batch_size,
                                              collate_fn=data_collator)
            else:
                sampler = DistributedSampler(
                    eval_ds,
                    num_replicas=torch.distributed.get_world_size(),
                    rank=training_args.local_rank,
                    seed=training_args.seed,
                )
                data_loader_eval = DataLoader(dataset=eval_ds, shuffle=False,
                                              batch_size=training_args.per_device_train_batch_size,
                                              collate_fn=data_collator, sampler=sampler, drop_last=False)
            # Only the first batch is scored: a cheap estimate of how useful
            # each hint source is for this format.
            for data in data_loader_eval:
                labels = data.pop("labels")
                data["eval_labels"] = labels
                data = {x: data[x].to(torch.device("cuda", device_id)) for x in data if data[x] is not None}
                scorer = get_model_obj(eval_model)  # works with and without DDP wrapping
                if source != "first":
                    out = scorer.eval_mode(**data)
                    priority_level[format_name][source] += out.mean().item()
                else:
                    out = scorer.eval_mode16(**data)
                    priority_level[format_name]["qa"] += out.mean().item()
                break

    return priority_level


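# One distillation pass between the hint retriever (bi-encoder) and the
# reranker (cross-encoder) over 64 candidate hints per query:
#   1. the bi-encoder scores all 64 candidates and the top 16 are kept;
#   2. the cross-encoder rescores those 16;
#   3. a KL loss distills one distribution into the other (`fwd` picks the
#      direction), and optionally QA-derived scores over the 4 previously
#      selected hints supervise either model (`qaret`, `qarer`).
# The offset tensors turn per-query top-k indices into indices over the
# flattened (batch * candidates) axis: e.g. with bs_size=2 and 64 candidates,
# per-query indices [[3, 7], [1, 2]] become flat indices [3, 7, 65, 66].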
def trainbc(bencoder, cencoder, dataset, training_args, fwd, qaret, qarer):
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in bencoder.named_parameters() if not any(
            nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in bencoder.named_parameters() if any(
            nd in n for nd in no_decay)], "weight_decay": 0.0},
        {"params": [p for n, p in cencoder.named_parameters() if not any(
            nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in cencoder.named_parameters() if any(
            nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=1e-5,
        # eps=args.adam_epsilon
    )

    device_id = training_args.local_rank
    loss_fct = torch.nn.KLDivLoss()
    if device_id == -1:
        device_id = 0
    bs_size = 2

    # Flat-index offsets: `offset` maps per-query top-16 indices into the
    # flattened (bs * 64) candidate axis; `offset3` does the same for groups
    # of 4; `offset2` is kept for parity with filthints but unused here.
    offset = []
    offset2 = []
    offset3 = []
    scaler = amp.GradScaler(enabled=True)
    for itf in range(bs_size):
        offset.extend([itf * 64] * 16)
        offset2.extend([itf * 16] * 4)
        offset3.extend([itf * 64] * 4)
    offset = torch.Tensor(offset).long().to(torch.device("cuda", device_id))
    offset2 = torch.Tensor(offset2).long().to(torch.device("cuda", device_id))
    offset3 = torch.Tensor(offset3).long().to(torch.device("cuda", device_id))
    select_range = 640

    if len(dataset) < select_range:
        select_range = len(dataset)
    if training_args.local_rank == -1:
        data_loader_train = DataLoader(dataset=dataset.select(range(100)), batch_size=bs_size,
                                       collate_fn=default_data_collator)
    else:
        # Note: the sampler is built over the first `select_range` examples, so
        # distributed training only ever draws from that prefix of `dataset`.
        sampler = DistributedSampler(
            dataset.select(range(select_range)),
            num_replicas=torch.distributed.get_world_size(),
            rank=training_args.local_rank,
            seed=training_args.seed,
        )
        data_loader_train = DataLoader(dataset=dataset, batch_size=bs_size, sampler=sampler,
                                       drop_last=False, collate_fn=default_data_collator)
    for i, data in enumerate(data_loader_train):
        # Rebuild the offsets for a short final batch.
        if data["query_ids"].size(0) != bs_size:
            bs_size = data["query_ids"].size(0)
            offset = []
            offset2 = []
            offset3 = []
            for itf in range(bs_size):
                offset.extend([itf * 64] * 16)
                offset2.extend([itf * 16] * 4)
                offset3.extend([itf * 64] * 4)
            offset = torch.Tensor(offset).long().to(torch.device("cuda", device_id))
            offset2 = torch.Tensor(offset2).long().to(torch.device("cuda", device_id))
            offset3 = torch.Tensor(offset3).long().to(torch.device("cuda", device_id))
        # All 64 candidate rows share one query; keep a single copy per sample.
        data["query_ids"] = data["query_ids"][:, 0, :].squeeze(1)
        data["query_attentions"] = data["query_attentions"][:, 0, :].squeeze(1)
        data = {x: data[x].to(torch.device("cuda", device_id)) for x in data if data[x] is not None}
        with torch.no_grad():
            bmodel_out = bencoder(
                data["query_ids"].view(-1, 112),
                data["query_attentions"].view(-1, 112),
                data["ctx_ids"].view(-1, 112),
                data["ctx_attentions"].view(-1, 112),
            )

        # Mask out padded candidate slots (sub_ids >= 999).
        score_mask = (data["sub_ids"] >= 999).int().mul(-999)
        local_q_vector, local_ctx_vectors = bmodel_out
        local_q_vector = local_q_vector.view(-1, 768)
        local_ctx_vectors = local_ctx_vectors.view(-1, 64, 768)

        # (bs, 1, 768) x (bs, 768, 64) -> (bs, 64) dot-product scores.
        sim_scores = torch.bmm(
            local_q_vector.unsqueeze(1), torch.transpose(local_ctx_vectors, 1, 2)
        ).squeeze(1)
        sim_scores = sim_scores.add(score_mask)

        # Keep the retriever's top 16 candidates per query.
        sort_result, sort_idxs = sim_scores.topk(16)
        sort_idxs = sort_idxs.view(-1).add(offset)
        kl_loss_qa1 = 0.0
        kl_loss_qa2 = 0.0

        if "qa_scores" in data.keys() and (qaret or qarer):
            # Gather the 4 hints selected in the previous round (falling back
            # to the first 4) so QA scores can supervise retriever/reranker.
            with torch.no_grad():
                previous_ids = []
                sids = data["sample_id"].cpu().numpy().tolist()
                for sid in sids:
                    try:
                        previous_ids.extend(format_hints_ids[sid])
                    except KeyError:
                        previous_ids.extend([0, 1, 2, 3])
                previous_ids = torch.tensor(previous_ids).long().to(torch.device("cuda", device_id)).add(offset3)
                previous_data = {}
                for k_ in data.keys():
                    if len(data[k_].size()) == 3:
                        previous_data[k_] = torch.index_select(
                            data[k_].view(bs_size * 64, -1), dim=0, index=previous_ids
                        ).view(bs_size, 4, -1).clone()
                    elif len(data[k_].size()) == 2:
                        if k_ == "query_ids" or k_ == "query_attentions":
                            previous_data[k_] = data[k_].clone()
                        elif k_ != "qa_scores":
                            previous_data[k_] = torch.index_select(
                                data[k_].view(-1), dim=0, index=previous_ids
                            ).view(bs_size, 4).clone()
                        else:
                            previous_data[k_] = data[k_].clone()
                    else:
                        if k_ == "sample_id":
                            previous_data[k_] = data[k_].clone()
                        else:
                            assert False

            cmodel_previous = cencoder(
                previous_data["cross_ids"].view(-1, 144),
                previous_data["cross_attentions"].view(-1, 144),
                previous_data["cross_ctxs"].view(-1, 144),
            ).squeeze(-1).view(-1, 4)
            dmodel_previous = bencoder(
                previous_data["query_ids"].view(-1, 112),
                previous_data["query_attentions"].view(-1, 112),
                previous_data["ctx_ids"].view(-1, 112),
                previous_data["ctx_attentions"].view(-1, 112),
            )

            local_q_vector, local_ctx_vectors = dmodel_previous
            local_q_vector = local_q_vector.view(-1, 768)
            local_ctx_vectors = local_ctx_vectors.view(-1, 4, 768)
            sim_scores = torch.bmm(
                local_q_vector.unsqueeze(1), torch.transpose(local_ctx_vectors, 1, 2)
            ).squeeze(1)

            # KLDivLoss(input=log-probs, target=probs): distill the QA scores
            # into each model's distribution over the 4 previous hints.
            bi_score_previous = torch.nn.functional.log_softmax(sim_scores, dim=-1)
            c_score_previous = torch.nn.functional.log_softmax(cmodel_previous, dim=-1)
            qa_scores = torch.softmax(data["qa_scores"], dim=-1)

            if qarer:
                kl_loss_qa1 = loss_fct(c_score_previous, qa_scores)
            if qaret:
                kl_loss_qa2 = loss_fct(bi_score_previous, qa_scores)
            kl_qa_all = kl_loss_qa1 + kl_loss_qa2
            scaler.scale(kl_qa_all).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Narrow every per-candidate tensor down to the retriever's top 16.
        for k_ in data.keys():
            if len(data[k_].size()) == 3:
                data[k_] = torch.index_select(data[k_].view(bs_size * 64, -1), dim=0, index=sort_idxs).view(bs_size, 16, -1)
            elif len(data[k_].size()) == 2:
                if k_ != "qa_scores" and k_ != "query_ids" and k_ != "query_attentions":
                    data[k_] = torch.index_select(data[k_].view(-1), dim=0, index=sort_idxs).view(bs_size, 16)
            else:
                if k_ == "sample_id":
                    pass
                else:
                    print(k_)
                    print(data[k_])
                    assert False

        cmodel_out = cencoder(
            data["cross_ids"].view(-1, 144),
            data["cross_attentions"].view(-1, 144),
            data["cross_ctxs"].view(-1, 144),
        ).squeeze(-1).view(-1, 16)
        dmodel_out = bencoder(
            data["query_ids"].view(-1, 112),
            data["query_attentions"].view(-1, 112),
            data["ctx_ids"].view(-1, 112),
            data["ctx_attentions"].view(-1, 112),
        )

        local_q_vector, local_ctx_vectors = dmodel_out
        local_q_vector = local_q_vector.view(-1, 768)
        local_ctx_vectors = local_ctx_vectors.view(-1, 16, 768)
        sim_scores = torch.bmm(
            local_q_vector.unsqueeze(1), torch.transpose(local_ctx_vectors, 1, 2)
        ).squeeze(1)

        # Mutual distillation over the shared top 16: `fwd` distills the
        # bi-encoder's distribution into the cross-encoder, else the reverse.
        if fwd:
            bi_score = torch.softmax(sim_scores, dim=-1)
            c_score = torch.nn.functional.log_softmax(cmodel_out, dim=-1)
            kl_loss = loss_fct(c_score, bi_score)
        else:
            bi_score = torch.nn.functional.log_softmax(sim_scores, dim=-1)
            c_score = torch.softmax(cmodel_out, dim=-1)
            kl_loss = loss_fct(bi_score, c_score)

        scaler.scale(kl_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # (Per-sample top-4 selection bookkeeping lives in filthints below.)


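# Inference-time hint selection: retrieve with the bi-encoder (top 16 of 64),
# rerank with the cross-encoder (top 4 of 16), and record the chosen candidate
# sub-ids per sample in format_hints_ids (dev) or format_hints_ids_test (test).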
def filthints(bencoder, cencoder, dataset, training_args, fwd, qaret, qarer, mode):
    device_id = training_args.local_rank
    if device_id == -1:
        device_id = 0
    bencoder.eval()
    cencoder.eval()
    bs_size = 96
    offset = []
    offset2 = []
    offset3 = []
    for itf in range(bs_size):
        offset.extend([itf * 64] * 16)
        offset2.extend([itf * 16] * 4)
        offset3.extend([itf * 64] * 4)
    offset = torch.Tensor(offset).long().to(torch.device("cuda", device_id))
    offset2 = torch.Tensor(offset2).long().to(torch.device("cuda", device_id))
    offset3 = torch.Tensor(offset3).long().to(torch.device("cuda", device_id))
    if training_args.local_rank == -1:
        data_loader_train = DataLoader(dataset, shuffle=True, batch_size=bs_size, collate_fn=default_data_collator)
    else:
        sampler = DistributedSampler(
            dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=training_args.local_rank,
            seed=training_args.seed,
        )
        data_loader_train = DataLoader(dataset=dataset, batch_size=bs_size, sampler=sampler,
                                       collate_fn=default_data_collator, drop_last=False)
    for i, data in tqdm(enumerate(data_loader_train)):
        # Rebuild the flat-index offsets for a short final batch.
        if data["query_ids"].size(0) != bs_size:
            bs_size = data["query_ids"].size(0)
            offset = []
            offset2 = []
            offset3 = []
            for itf in range(bs_size):
                offset.extend([itf * 64] * 16)
                offset2.extend([itf * 16] * 4)
                offset3.extend([itf * 64] * 4)
            offset = torch.Tensor(offset).long().to(torch.device("cuda", device_id))
            offset2 = torch.Tensor(offset2).long().to(torch.device("cuda", device_id))
            offset3 = torch.Tensor(offset3).long().to(torch.device("cuda", device_id))
        # Timing hook (measurement prints removed).
        torch.cuda.synchronize()
        start = time.time()
        data = {x: data[x].to(torch.device("cuda", device_id)) for x in data if data[x] is not None}
        with torch.no_grad():
            # Stage 1: bi-encoder retrieval over all 64 candidates.
            bmodel_out = bencoder(
                data["query_ids"].view(-1, 112),
                data["query_attentions"].view(-1, 112),
                data["ctx_ids"].view(-1, 112),
                data["ctx_attentions"].view(-1, 112),
            )
            # Mask out padded candidate slots (sub_ids >= 999).
            score_mask = (data["sub_ids"] >= 999).int().mul(-999)
            local_q_vector, local_ctx_vectors = bmodel_out
            local_q_vector = local_q_vector.view(-1, 64, 768)[:, 0, :].view(-1, 768)
            local_ctx_vectors = local_ctx_vectors.view(-1, 64, 768)

            # (bs, 1, 768) x (bs, 768, 64) -> (bs, 64) dot-product scores.
            sim_scores = torch.bmm(
                local_q_vector.unsqueeze(1), torch.transpose(local_ctx_vectors, 1, 2)
            ).squeeze(1)
            sim_scores = sim_scores.add(score_mask)

            sort_result, sort_idxs = sim_scores.topk(16)
            sort_idxs = sort_idxs.view(-1).add(offset)

            # Narrow every per-candidate tensor to the retriever's top 16.
            for k_ in data.keys():
                if len(data[k_].size()) == 3:
                    data[k_] = torch.index_select(data[k_].view(bs_size * 64, -1), dim=0, index=sort_idxs).view(bs_size, 16, -1)
                elif len(data[k_].size()) == 2:
                    if k_ != "qa_scores":
                        data[k_] = torch.index_select(data[k_].view(-1), dim=0, index=sort_idxs).view(bs_size, 16)
                else:
                    if k_ == "sample_id":
                        pass
                    else:
                        print(k_)
                        print(data[k_])
                        assert False

            # Stage 2: cross-encoder reranking of the 16 survivors.
            cmodel_out = cencoder(
                data["cross_ids"].view(-1, 144),
                data["cross_attentions"].view(-1, 144),
                data["cross_ctxs"].view(-1, 144),
            ).squeeze(-1).view(-1, 16)
            dmodel_out = bencoder(
                data["query_ids"].view(-1, 112),
                data["query_attentions"].view(-1, 112),
                data["ctx_ids"].view(-1, 112),
                data["ctx_attentions"].view(-1, 112),
            )

            local_q_vector, local_ctx_vectors = dmodel_out
            local_q_vector = local_q_vector.view(-1, 16, 768)[:, 0, :].view(-1, 768)
            local_ctx_vectors = local_ctx_vectors.view(-1, 16, 768)
            sim_scores = torch.bmm(
                local_q_vector.unsqueeze(1), torch.transpose(local_ctx_vectors, 1, 2)
            ).squeeze(1)

            # Bi-encoder scores over the 16 survivors; kept for parity with
            # trainbc but unused for the selection below.
            bi_score = sim_scores
            c_score = cmodel_out

            _, rerank_idxs = cmodel_out.topk(4)

            # Record the reranker's top-4 candidate sub-ids for each sample.
            for bat in range(rerank_idxs.size(0)):
                sel = torch.index_select(data["sub_ids"][bat], dim=0, index=rerank_idxs[bat])
                ids = data["sample_id"][bat]

                if mode == "dev":
                    format_hints_ids[ids.item()] = sel.cpu().numpy().tolist()
                else:
                    format_hints_ids_test[ids.item()] = sel.cpu().numpy().tolist()


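# Align GLM-generated hint lines with the evaluation examples of one format,
# join each example's reranked top-4 hints with " ; ", attach them as a
# "hintret" column, and tokenize/save the per-dataset eval sets to ./final_eval.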
def generate_samples_unseen(format_name, format_hints, epoch_id):
    # Note: selections are read from the global format_hints_ids; the
    # `format_hints` argument is kept for interface compatibility but unused.
    format_ids_seq = []

    all_format_size = 0
    base_id = 2e5 * format2id[format_name]
    for item in dataset_files:
        if item in format2dataset[format_name]:
            app = load_from_disk("./processed/{}-evalex.hf".format(item))
            all_format_size += len(app)
    format_ids_seq = [(i + base_id) for i in range(all_format_size)]

    with open("./textinput/{}test-glminput1.jsonl".format(format_name), 'r', encoding='utf-8') as pdz:
        pdz_lines = pdz.readlines()
    pdz_lines = [json.loads(item) for item in pdz_lines]
    pdzs = [item['id'] for item in pdz_lines]
    print(all_format_size, pdzs[-1])
    # Each generated line carries the example id it belongs to; the last id
    # must cover every example of this format.
    assert pdzs[-1] + 1 == all_format_size
    splits = {"extractive": 1, "abstractive": 1, "multichoice": 1, "bool": 1}
    splitn = splits[format_name]
    total_json = []
    for idx in range(int(splitn)):
        with open("./plmresource/{}-glmout.json".format(format_name + "test"), 'r', encoding='utf-8') as a:
            b = json.load(a)["data"]
        total_json.extend(b)
    top_line = []

    current_saved = []
    clk = 1
    # Group consecutive generated hints by example id; whenever the id
    # advances, keep the reranked top-4 of the saved group, joined with " ; ".
    for left, right in zip(total_json, pdzs):
        if right == clk:
            read_ = []
            idss = format_hints_ids[int(format_ids_seq[clk - 1])]
            for i_ in idss:
                read_.append(current_saved[i_])
            clk += 1
            top_line.append(" ; ".join(read_))
            current_saved = []

        current_saved.append(left)
    # Flush the final group.
    read_ = []
    idss = format_hints_ids[int(format_ids_seq[clk - 1])]
    for i_ in idss:
        read_.append(current_saved[i_])
    top_line.append(" ; ".join(read_))

    all_size = 0
    start = 0
    global tsname
    for item in dataset_files:
        fm = dataset2format[item]
        tsname = item
        if fm == format_name:
            try:
                data_path = "./data_process/data/{}/dev.json".format(item)
                dataset = load_dataset("json", data_files=data_path)["train"]
            except FileNotFoundError:
                data_path = "./data_process/data/{}/test.json".format(item)
                dataset = load_dataset("json", data_files=data_path)["train"]
            all_size += len(dataset)
            end = start + len(dataset)
            hints = top_line[start:end]

            dataset = dataset.add_column("hintret", hints)

            dataset_ret = dataset.map(
                lambda x: preprocess_function_eval(x, format_name),
                batched=True,
                remove_columns=["input", "output", "hintret"],
                load_from_cache_file=True,
                desc="Running tokenizer on eval dataset",
            )
            dataset_ret.save_to_disk("./final_eval/{}-eval.hf".format(item))

            start = end


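# main(): parse arguments, build the tokenizer/model with the extra QA special
# tokens, optionally re-run hint selection for the unseen test sets, then
# evaluate each unseen dataset with its task-appropriate metric.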
def main():
    def compute_metrics(p: EvalPrediction):
        # `metric` is rebound per dataset in the evaluation loop below.
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Local override of the module-level task2id: 'same' checkpoints share one
    # id per format (e.g. squad and extractive both map to 0).
    if 'same' in model_args.model_name_or_path:
        task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive': 1, 'race': 2, 'multichoice': 2,
                   'boolq': 3, 'bool': 3, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
                   'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}
    else:
        task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5,
                   'boolq': 6, 'bool': 7, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
                   'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}

    # Metric script used for each evaluation dataset.
    dataset_name_to_metric = {
        'squad1_1': 'metric/squad_v1_local/squad_v1_local.py',
        'squad2': 'metric/squad_v2_local/squad_v2_local.py',
        'newsqa': 'metric/squad_v1_local/squad_v1_local.py',
        'boolq': 'metric/squad_v1_local/squad_v1_local.py',
        'narrativeqa_dev': 'metric/rouge_local/rouge_metric.py',
        'race_string': 'metric/accuracy.py',
        'quoref': 'metric/squad_v1_local/squad_v1_local.py',
        'ropes': 'metric/squad_v1_local/squad_v1_local.py',
        'drop': 'metric/squad_v1_local/squad_v1_local.py',
        'natural_questions_with_dpr_para': 'metric/squad_v1_local/squad_v1_local.py',
        'boolq_np': 'metric/squad_v1_local/squad_v1_local.py',
        'openbookqa': 'metric/accuracy.py',
        'arc_hard': 'metric/accuracy.py',
        'arc_easy': 'metric/accuracy.py',
        'mctest_corrected_the_separator': 'metric/accuracy.py',
        'social_iqa': 'metric/accuracy.py',
        'dream': 'metric/accuracy.py',
        'commonsenseqa': 'metric/accuracy.py',
        'qasc': 'metric/accuracy.py',
        'physical_iqa': 'metric/accuracy.py',
        'winogrande_xl': 'metric/accuracy.py',
        'multirc': 'metric/squad_v1_local/squad_v1_local.py',
        'onestopqa_advanced': 'metric/accuracy.py',
        'onestopqa_elementry': 'metric/accuracy.py',
        'onestopqa_intermediate': 'metric/accuracy.py',
        'prost_multiple_choice_with_no_context': 'metric/accuracy.py',
        'processbank_test': 'metric/accuracy.py',
        'cosmosqa': 'metric/accuracy.py',
        'mcscript': 'metric/accuracy.py',
        'mcscript2': 'metric/accuracy.py',
        'quail': 'metric/accuracy.py',
        'reclor': 'metric/accuracy.py',
        'measuring_massive_multitask_language_understanding': 'metric/accuracy.py',
        'head_qa_en_test': 'metric/accuracy.py',
        'race_c': 'metric/accuracy.py',
        'pubmedqa_pqal_short_ans': 'metric/squad_v1_local/squad_v1_local.py',
        'strategyqa': 'metric/squad_v1_local/squad_v1_local.py',
        'tweetqa': 'metric/bleu_local/bleu.py',
        'qaconv': 'metric/squad_v1_local/squad_v1_local.py',
        'record_extractive': 'metric/squad_v1_local/squad_v1_local.py',
        'adversarialqa_dbert_dev': 'metric/squad_v1_local/squad_v1_local.py',
        'adversarialqa_dbidaf_dev': 'metric/squad_v1_local/squad_v1_local.py',
        'adversarialqa_droberta_dev': 'metric/squad_v1_local/squad_v1_local.py'
    }

    # Detect the last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)

    # Set seed before initializing the model.
    set_seed(training_args.seed)

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    global tokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Format markers plus structural special tokens. With the stock T5
    # vocabulary (32100 entries) the added ids run 32100..32108, so '[HINT]'
    # lands on 32108, the separator id hard-coded in the preprocessing above.
    tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
    special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
                                                         '[OPTIONS]', '[HINT]']}

    tokenizer.add_tokens(tokens_to_add)
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    added_tokens = tokenizer.get_added_vocab()

    model = PromptT5.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # model.copy_encoder()
    model.resize_token_embeddings(len(tokenizer))

    global max_source_length
    max_source_length = 1024

    # Set decoder_start_token_id for mBART-style tokenizers.
    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Probe GPU availability (values are not used directly below).
    if training_args.local_rank == -1 or training_args.no_cuda:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    question_column = data_args.question_column
    context_column = data_args.context_column
    answer_column = data_args.answer_column

    data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Seen training tasks and unseen evaluation tasks.
    train_dataloaders = {}
    eval_dataloaders = {}
    replay_dataloaders = {}
    to_be_train = ["squad1_1", "squad2", "narrativeqa_dev", "boolq", "arc_hard", "arc_easy", "openbookqa", "race_string", "mctest_corrected_the_separator", "newsqa", "quoref", "ropes", "drop", "natural_questions_with_dpr_para", "commonsenseqa", "qasc", "physical_iqa", "social_iqa", "winogrande_xl", "multirc", "boolq_np"]

    to_be_eval = "onestopqa_advanced,onestopqa_elementry,onestopqa_intermediate,prost_multiple_choice_with_no_context,dream,processbank_test,cosmosqa,mcscript,mcscript2,quail,reclor,measuring_massive_multitask_language_understanding,head_qa_en_test,race_c,pubmedqa_pqal_short_ans,strategyqa,tweetqa,qaconv,record_extractive,adversarialqa_dbert_dev,adversarialqa_dbidaf_dev,adversarialqa_droberta_dev".split(",")

    # (Overridden to None just before tester.evaluate below.)
    max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams

    if training_args.local_rank != -1:
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    device_id = training_args.local_rank if training_args.local_rank != -1 else 0

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4, eps=1e-8)

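    # Hint retriever (bi-encoder) and reranker (cross-encoder), wrapped in DDP
    # when running distributed.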
    tensorizer, bi_encoder, _ = init_biencoder_components(
        "hf_bert", {}, inference_only=True
    )
    c_encoder = BertEncoder_For_CrossEncoder.from_pretrained(
        "bert-base-uncased"
    )

    bi_encoder = bi_encoder.to(torch.device("cuda", device_id))
    c_encoder = c_encoder.to(torch.device("cuda", device_id))
    if training_args.local_rank != -1:
        bi_encoder = nn.parallel.DistributedDataParallel(
            bi_encoder,
            device_ids=[training_args.local_rank] if training_args.n_gpu != 0 else None,
            output_device=training_args.local_rank if training_args.n_gpu else None,
            find_unused_parameters=True,
        )
        c_encoder = nn.parallel.DistributedDataParallel(
            c_encoder,
            device_ids=[training_args.local_rank] if training_args.n_gpu != 0 else None,
            output_device=training_args.local_rank if training_args.n_gpu else None,
        )
    global all_cards

    if training_args.local_rank == -1:
        all_cards = [-1]
    else:
        all_cards = list(range(world_size))
    if training_args.local_rank <= 0:
        print("all_cards:", all_cards)

    global format_name
    global priority_level
    priority_level = {}
    device = training_args.local_rank
    if device < 0:
        device = 0

    with torch.no_grad():
        epoch_id = 1

        map_location = torch.device("cuda", device)
        dev_format2size = {"bool": 3, "multichoice": 3, "extractive": 7, "abstractive": 6}
        test_format2size = {"bool": 1, "multichoice": 7, "extractive": 3, "abstractive": 1}
        tlm = get_model_obj(c_encoder)
        tlm.load_state_dict(torch.load("./rersave/cencoder-{}.pt".format(str(epoch_id)), map_location=map_location))

        global format_hints_ids
        global format_hints_ids_test

        format_hints_ids = {}
        format_hints_ids_test = {}
        # Flip to False to re-run retrieval + reranking over the test shards
        # instead of loading cached selections.
        skip_selection = True

        for format_name in ["bool", "extractive", "abstractive", "multichoice"]:
            if skip_selection:
                break
            start_line = 0
            end_line = 0
            to_load_model = get_model_obj(bi_encoder)
            to_load_model.load_state_dict(torch.load("./retsave/biencoder-{}.pt".format(str(epoch_id)), map_location=map_location))

            for sub_set in range(test_format2size[format_name]):
                rr_dev_dataset = load_from_disk("./sel_file/{}test-retrak{}.hf".format(format_name, str(sub_set)))

                with amp.autocast(enabled=True):
                    # trainbc(bi_encoder, c_encoder, rr_train_dataset, training_args, forward_train, qa2ret, qa2rer)
                    filthints(bi_encoder, c_encoder, rr_dev_dataset, training_args, False, False, False, "test")

        if not skip_selection:
            # Each rank dumps its share of the selections; merged below.
            with open("./mem_scores/format_hintstest-{}{}.json".format(str(training_args.local_rank), str(epoch_id)), 'w') as fout:
                json.dump(format_hints_ids_test, fout)

        if training_args.local_rank != -1:
            torch.distributed.barrier()
        format_hints_ids = load_hints_test(all_cards, epoch_id)

        # Hint-augmented eval sets are regenerated only when this loop is
        # enabled; the leading `break` leaves the cached ./final_eval data as-is.
        for format_name in ["bool", "extractive", "abstractive", "multichoice"]:
            break
            generate_samples_unseen(format_name, format_hints_ids, epoch_id)

    if training_args.local_rank != -1:
        torch.distributed.barrier()
    eval_ds = {}
    eval_exp = {}

    # Reload the epoch checkpoint and the cached hint-augmented eval sets.
    model = PromptT5.from_pretrained("./epoch_ckpt{}".format(str(epoch_id)))
    eval_dataloaders = {}
    for item in to_be_eval:
        eval_dataloaders[item] = (load_from_disk("./final_eval/{}-eval.hf".format(item)),
                                  load_from_disk("./processed/{}-evalex.hf".format(item)))
    model.set_test()
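    # Evaluate every unseen dataset with QuestionAnsweringTrainer, loading the
    # metric script matching its answer format (EM/F1, accuracy, ROUGE, or BLEU).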
    for item in to_be_eval:
        print("current_test:", item)
        eval_dataset, eval_examples = eval_dataloaders[item]
        len_ds = len(eval_dataset)
        range_ds = list(range(len_ds))
        eval_dataset = eval_dataset.add_column("id", range_ds).add_column("example_id", range_ds)

        metric = load_metric(dataset_name_to_metric[item])
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )
        tester = QuestionAnsweringTrainer(
            model=model,
            args=training_args,
            train_dataset=None,
            eval_dataset=eval_dataset,
            eval_examples=eval_examples,
            answer_column_name=answer_column,
            dataset_name=item,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        )

        # Generation arguments are left at the trainer defaults.
        max_length, num_beams, ignore_keys_for_eval = None, None, None
        metrics = tester.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,
                                  metric_key_prefix="eval")
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        tester.log_metrics("eval", metrics)

    return None


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()