add: diana

Author: 出蛰
Date:   2023-05-18 20:12:38 +08:00
parent f570084ce5
commit 99c2ba1195
59 changed files with 27198 additions and 0 deletions

diana/data_process/QAInput.py (Executable file, +69 lines)
@@ -0,0 +1,69 @@
# ProQA
# [paper] https://arxiv.org/abs/2205.04040
class StructuralQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'[TASK] [ABSTRACTIVE] [QUESTION] {question}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'[TASK] [BOOL] [QUESTION] {question}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'[TASK] [EXTRACTIVE] [QUESTION] {question.lower().capitalize()}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'[TASK] [MultiChoice] [QUESTION] {question} [OPTIONS] {options} [CONTEXT] {context} the answer is: '
        return source_text


class LFQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'{question.lower().capitalize()} \\n {context}'
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'{question} \\n {options} \\n {context}'
        return source_text


class SimpleQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'{question.lower().capitalize()} \\n {context}'
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'{question} \\n {options} \\n {context}'
        return source_text
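
For reference, a minimal sketch of what these formatters produce (the question and context strings below are invented):

context = "Diana was proposed in 2023."
question = "When was Diana proposed?"

# ProQA-style structured input with task/format markers:
print(StructuralQAInput.qg_input_extractive_qa(context, question))
# -> '[TASK] [EXTRACTIVE] [QUESTION] When was diana proposed?. [CONTEXT] Diana was proposed in 2023. the answer is: '

# Plain UnifiedQA-style input; '\\n' renders as a literal "\n" separator:
print(SimpleQAInput.qg_input_extractive_qa(context, question))
# -> 'When was diana proposed? \n Diana was proposed in 2023.'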

diana/data_process/QGInput.py (Executable file, +88 lines)
@@ -0,0 +1,88 @@
class QGInput:
    @classmethod
    def qg_input_abstractiveqa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}, the answer is: {ans}. This is a summary task, please generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_abstractiveqa_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and the answer: The context: {context}.'
            target_text = f'[Question]: {question} [Answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_boolqa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}. The answer is: <hl> {ans} <hl>. This is a bool task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_extractive_qa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            context = context.replace(ans, f'<hl> {ans} <hl>')
            source_text = f'The context: {context}. The answer is: {ans}. This is a extractive task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_extractive_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and the answer: [Context] {context}.'
            target_text = f'[question] {question} [answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_multirc(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}. The options are: {options}. The answer is: <hl> {ans} <hl>. This is a machine reading comprehension task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_negoption(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            for option in options:
                if option != ans:
                    source_text = f'Context: {context}. Answer is: <hl> {ans} <hl>. Generate false option: '
                    target_text = f'{option}'
                    outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and answer: [Context] {context}'
            target_text = f'[Question] {question} [Answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_negoption_withq(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            for option in options:
                if option != ans:
                    source_text = f'Generate false option: [Context] {context}. [Question] {question}. [Answer] <hl> {ans} <hl>.'
                    target_text = f'{option}'
                    outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_qfirst(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question: Context: {context}. Answer: <hl> {ans} <hl>. '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs
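
A minimal usage sketch for the question-generation formats above (the example strings are invented):

context = "The capital of France is Paris."
question = "What is the capital of France?"
answers = ["Paris"]

pairs = QGInput.qg_input_extractive_qa(context, question, answers)
# pairs[0]['source_text'] highlights the answer span in the context:
# 'The context: The capital of France is <hl> Paris <hl>.. The answer is: Paris. This is a extractive task, generate question: '
# pairs[0]['target_text'] is the question to be generated:
# 'What is the capital of France?'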

diana/data_process/tasks.py (Executable file, +156 lines)
@@ -0,0 +1,156 @@
import logging
import t5
import os
import json
import functools
import tensorflow as tf
import tensorflow_datasets as tfds

DATASETS = [
    "ai2_science_middle",
    "ai2_science_elementary",
    "arc_hard",
    "arc_easy",
    "mctest",
    "mctest_corrected_the_separator",
    "natural_questions",
    "quoref",
    "squad1_1",
    "squad2",
    "boolq",
    "multirc",
    "newsqa",
    "race_string",
    "ropes",
    "drop",
    "narrativeqa",
    "openbookqa",
    "qasc",
    "boolq_np",
    "contrast_sets_boolq",
    "contrast_sets_drop",
    "contrast_sets_quoref",
    "contrast_sets_ropes",
    "commonsenseqa",
    "qasc_with_ir",
    "openbookqa_with_ir",
    "arc_easy_with_ir",
    "arc_hard_with_ir",
    "ambigqa",
    "natural_questions_direct_ans",
    "natural_questions_with_dpr_para",
    "winogrande_xs",
    "winogrande_s",
    "winogrande_m",
    "winogrande_l",
    "winogrande_xl",
    "social_iqa",
    "physical_iqa",
]

DATA_DIR = "gs://unifiedqa/data/"


def dataset_preprocessor(ds):
    def normalize_text(text):
        """Lowercase and remove quotes from a TensorFlow string."""
        text = tf.strings.lower(text)
        text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
        return text

    def to_inputs_and_targets(ex):
        return {
            "inputs": normalize_text(ex["inputs"]),
            "targets": normalize_text(ex["targets"])
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


def get_path(data_dir1, split):
    tsv_path = {
        "train": os.path.join(data_dir1, "train.tsv"),
        "dev": os.path.join(data_dir1, "dev.tsv"),
        "test": os.path.join(data_dir1, "test.tsv")
    }
    return tsv_path[split]


def dataset_fn(split, shuffle_files=False, dataset=""):
    # We only have one file for each split.
    del shuffle_files
    # Load lines from the text file as examples.
    ds = tf.data.TextLineDataset(get_path(DATA_DIR + dataset, split))
    # Split each "<input>\t<target>" line into an (input, target) tuple.
    print(" >>>> about to read csv . . . ")
    ds = ds.map(
        functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                          field_delim="\t", use_quote_delim=False),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # print(" >>>> after reading csv . . . ")
    # Map each tuple to a {"inputs": ..., "targets": ...} dict.
    ds = ds.map(lambda *ex: dict(zip(["inputs", "targets"], ex)))
    # print(" >>>> after mapping . . . ")
    return ds


for dataset in DATASETS:
    print(f" >>>> reading dataset: {dataset}")
    t5.data.set_tfds_data_dir_override(DATA_DIR + dataset)
    t5.data.TaskRegistry.add(
        f"{dataset}_task",
        # Supply a function which returns a tf.data.Dataset.
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=["train", "dev", "test"],
        # Supply a function which preprocesses text from the tf.data.Dataset.
        text_preprocessor=[dataset_preprocessor],
        # Use the same vocabulary that we used for pre-training.
        sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
        # Lowercase targets before computing metrics.
        postprocess_fn=t5.data.postprocessors.lower_text,
        metric_fns=[t5.evaluation.metrics.accuracy],
    )
    print(f" >>>> adding one mixture per dataset: `{dataset}_mixture`")
    t5.data.MixtureRegistry.add(
        f"{dataset}_mixture", [f"{dataset}_task"], default_rate=1.0
    )

# dataset-pair mixtures
for dataset1 in DATASETS:
    for dataset2 in DATASETS:
        if dataset1 == dataset2:
            continue
        print(f" >>>> adding one mixture for dataset-pair: `{dataset1}_{dataset2}_mixture`")
        t5.data.MixtureRegistry.add(
            f"{dataset1}_{dataset2}_mixture",
            [f"{dataset1}_task", f"{dataset2}_task"],
            default_rate=1.0
        )

# union mixture: used for training UnifiedQA
union_datasets = [
    "narrativeqa",
    "ai2_science_middle", "ai2_science_elementary",
    "arc_hard", "arc_easy",
    "mctest_corrected_the_separator",
    "squad1_1", "squad2",
    "boolq",
    "race_string",
    "openbookqa",
]
print(" >>>> adding one mixture for `union_mixture`")
t5.data.MixtureRegistry.add(
    "union_mixture",
    [f"{d}_task" for d in union_datasets],
    default_rate=1.0
)

# leave-one-out mixtures
for dd in union_datasets:
    filtered_datasets = [f"{d}_task" for d in union_datasets if d != dd]
    assert len(filtered_datasets) < len(union_datasets)
    t5.data.MixtureRegistry.add(
        f"union_minus_{dd}_mixture",
        filtered_datasets,
        default_rate=1.0
    )
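
Each registered task reads gs://unifiedqa/data/<dataset>/{train,dev,test}.tsv, one tab-separated input<TAB>target pair per line. A minimal local sketch of that loading step, reusing the same calls as dataset_fn (the local path is hypothetical):

import functools
import tensorflow as tf

path = "/tmp/toy_dataset/train.tsv"  # hypothetical local file in the same TSV layout
ds = tf.data.TextLineDataset(path)
ds = ds.map(functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                              field_delim="\t", use_quote_delim=False))
ds = ds.map(lambda *ex: dict(zip(["inputs", "targets"], ex)))
for ex in ds.take(1):
    print(ex["inputs"].numpy(), ex["targets"].numpy())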

@@ -0,0 +1,782 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
# from data import QAData
sys.path.append("../")
from models.nopt5 import T5ForConditionalGeneration as PromptT5
# from models.promptT5 import PromptT5
from downstream.dataset_processors import *
from downstream.trainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.12.0.dev0")
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
task2id = {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
task2format={}
for k,v in dataset2format.items():
if v=="extractive":
task2format[k]=0
elif v=="abstractive":
task2format[k]=1
else:
task2format[k]=2
# task2id[v] = len(task2id.keys())
# task2id = {v:k for k,v in enumerate(task_list)}
task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5, 'boolq': 6, 'bool': 7, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17}
# task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive':1 , 'race': 2, 'multichoice': 2, 'boolq': 3, 'bool': 3, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17}
# print(task2id)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks (task-prompt slots) supported by the model."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of QA task formats (format-prompt slots)."},
)
prompt_number: Optional[int] = field(
default=40,
metadata={"help": "The number of soft prompt tokens allocated to each prompt group."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the input."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "question","context","answer")
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)]
format_id = task2format[dataset_name]
format_prompt_id_start = 500
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40,
format_prompt_id_start + (format_id + 1) * 40)]
task_id = task2id[dataset_name]
task_prompt_id_start = 0
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40,
task_prompt_id_start + (task_id + 1) * 40)]
domain_prompt_id_start = 800
domain_prompt_number = 20
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
domain_prompt_id_start + 20)]*5
input_ids = copy.deepcopy(
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in
model_inputs['attention_mask']]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "question","context","answer")
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
model_inputs["labels"] = labels["input_ids"]
gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)]
format_id = task2format[dataset_name]
format_prompt_id_start = 500
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40,
format_prompt_id_start + (format_id + 1) * 40)]
task_id = task2id[dataset_name]
task_prompt_id_start = 0
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40,
task_prompt_id_start + (task_id + 1) * 40)]
domain_prompt_id_start = 800
domain_prompt_number = 20
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
domain_prompt_id_start + 20)]*5
input_ids = copy.deepcopy(
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in
model_inputs['attention_mask']]
return model_inputs
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if 'same' in model_args.model_name_or_path:
task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive': 1, 'race': 2, 'multichoice': 2,
'boolq': 3, 'bool': 3, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}
else:
task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5,
'boolq': 6, 'bool': 7, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}
dataset_name_to_metric = {
'squad': 'squad',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'accuracy',
'narrativeqa': 'rouge',
'race': 'accuracy',
'quoref': 'squad',
'ropes': 'squad',
'drop': 'squad',
'nqopen': 'squad',
# 'multirc': 'accuracy',
'boolq_np': 'accuracy',
'openbookqa': 'accuracy',
'mctest': 'accuracy',
'social_iqa': 'accuracy',
'dream': 'accuracy',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
#reload format specific task-prompt for newly involved task
#format_prompts###task_promptsf
data_args.reload_from_trained_prompt = False#@
data_args.load_from_format_task_id = False#@
### before pretrain come !!!!!!
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cuda")
n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
#start
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
for ds_name in task2id.keys():
if (not ds_name in ["extractive","abstractive","multichoice","bool","boolq","boolq_np","ropes"]):
# data_args.dataset_name = cur_dataset
cur_dataset = ds_name
data_args.dataset_name = cur_dataset
if True:
# Downloading and loading a dataset from the hub.
if not data_args.dataset_name in ['newsqa', 'nqopen', 'mctest', 'social_iqa']:
if data_args.dataset_name == "race":
data_args.dataset_config_name = "all"
elif data_args.dataset_name == "openbookqa":
data_args.dataset_config_name = "main"
elif data_args.dataset_name == "dream":
data_args.dataset_config_name = "plain_text"
else:
data_args.dataset_config_name = "plain_text"
raw_datasets = load_dataset(
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
)
if data_args.dataset_name in ['ropes']:
# add answer_start (not used for squad evaluation but required)
def add_answer_start(example):
example['answers'].update({"answer_start": [0]})
return example
raw_datasets = raw_datasets.map(add_answer_start)
elif data_args.dataset_name in ['drop']:
# add answer_start (not used for squad evaluation but required)
# add answers (for squad evaluation)
def add_answers(example):
answers = []
answer_start = []
for _a in example['answers_spans']['spans']:
answers.append(_a)
answer_start.append(-1)
example['answers'] = {"text": answers, "answer_start": answer_start}
return example
raw_datasets = raw_datasets.map(add_answers)
column_names = raw_datasets["validation"].column_names
else:
data_files = {}
basic_file = "../data_process/data/"+data_args.dataset_name+"/"
data_files["train"] = basic_file+"train.json"
# extension = data_args.train_file.split(".")[-1]
data_files["validation"] = basic_file+"dev.json"
# extension = data_args.validation_file.split(".")[-1]
if data_args.dataset_name in ['newsqa', 'nqopen', 'multirc', 'boolq_np', 'mctest', 'social_iqa']:
raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
else:
print(f"Unknown dataset {data_args.dataset_name}")
raise NotImplementedError
column_names = raw_datasets["validation"].column_names
metric = load_metric(dataset_name_to_metric[data_args.dataset_name])
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if True:
all_num = list(range(0, len(train_dataset)))
random.shuffle(all_num)
selected_indices = all_num[:100]
replay_dataset = train_dataset.select(selected_indices)
# train_dataset = train_dataset.select(range(data_args.max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
replay_dataset = replay_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on replay dataset",
)
replay_dataloaders[ds_name] = replay_dataset
# replay_dataset = load_from_disk("./processed/{}-replay.hf".format(ds_name))
# print(replay_dataset)
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on train dataset",
)
# print(train_dataset)
train_dataloaders[ds_name] = train_dataset
train_dataset.save_to_disk("./ours/{}-train.hf".format(ds_name))
# train_dataset = load_from_disk("./processed/{}-train.hf".format(ds_name))
max_target_length = data_args.val_max_target_length
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_examples = raw_datasets["validation"]
def add_id(example,index):
example.update({'id':index})
return example
if 'id' not in eval_examples.features.keys():
eval_examples = eval_examples.map(add_id,with_indices=True)
if data_args.max_eval_samples is not None:
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_examples.map(
preprocess_validation_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on validation dataset",
)
eval_dataloaders[ds_name] = (eval_dataset,eval_examples)
eval_dataset.save_to_disk("./ours/{}-eval.hf".format(ds_name))
eval_examples.save_to_disk("./ours/{}-evalex.hf".format(ds_name))
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
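
To make the prompt-id bookkeeping in preprocess_function easier to follow, here is a small self-contained sketch of the 200 placeholder ids it prepends; build_prompt_ids is a hypothetical helper, the offsets mirror the constants above, and the negative ids stand in for rows of the soft-prompt table (encoder.prompt_embeddings.weight) that save_prompt_embedding checkpoints:

def build_prompt_ids(format_id, task_id):
    """Reproduce the 200 soft-prompt placeholder ids prepended to every input."""
    gen_prompt_ids = [-(i + 1) for i in range(1520, 1520 + 20)]                 # 20 general prompts
    format_prompt_ids = [-(i + 1) for i in range(500 + format_id * 40,
                                                 500 + (format_id + 1) * 40)]   # 40 per format
    task_prompt_ids = [-(i + 1) for i in range(task_id * 40,
                                               (task_id + 1) * 40)]             # 40 per task
    domain_prompt_ids = [-(i + 1) for i in range(800, 800 + 20)] * 5            # 20 domain ids, repeated 5x
    return gen_prompt_ids + format_prompt_ids + task_prompt_ids + domain_prompt_ids

ids = build_prompt_ids(format_id=0, task_id=0)   # e.g. squad: extractive format, task id 0
assert len(ids) == 200                           # matches the [1] * 200 attention-mask prefix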

@@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import StructuralQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput
def preprocess_sqaud_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_sqaud_abstractive_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_boolq_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
question_column, context_column, answer_column = 'question', 'passage', 'answer'
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_boolq_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_narrativeqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [exp['summary']['text'] for exp in examples['document']]
questions = [exp['text'] for exp in examples['question']]
answers = [ans[0]['text'] for ans in examples['answers']]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_narrativeqa_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_drop_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['passage']
questions = examples['question']
answers = examples['answers']
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_race_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 }
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_newsqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
# inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_ropes_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
backgrounds = examples["background"]
situations = examples["situation"]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_openbookqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples['question_stem']
all_options = examples['choices']
answers = examples['answerKey']
options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_social_iqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2,}
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_dream_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
questions = examples['question']
all_options = examples['choice']
answers = examples['answer']
answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
targets = answers
return inputs, targets
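
A minimal sketch of how these batch processors are consumed; the mini-batch below is invented but follows the RACE schema that preprocess_race_batch reads (the column-name arguments are ignored by this particular processor):

batch = {
    "article": ["The moon orbits the earth."],
    "question": ["What does the moon orbit?"],
    "options": [["The sun", "The earth", "Mars", "Venus"]],
    "answer": ["B"],
}
inputs, targets = preprocess_race_batch(batch, "question", "article", "answer")
# inputs[0]:
# '[TASK] [MultiChoice] [QUESTION] What does the moon orbit? [OPTIONS] options: A. The sun; B. The earth; C. Mars; D. Venus [CONTEXT] The moon orbits the earth. the answer is: '
# targets[0]:
# 'The earth'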

@@ -0,0 +1,53 @@
function fulldata_hfdata() {
    SAVING_PATH=$1
    PRETRAIN_MODEL_PATH=$2
    TRAIN_BATCH_SIZE=$3
    EVAL_BATCH_SIZE=$4

    mkdir -p ${SAVING_PATH}

    python -u diana.py \
        --model_name_or_path ${PRETRAIN_MODEL_PATH} \
        --output_dir ${SAVING_PATH} \
        --dataset_name squad \
        --do_train \
        --do_eval \
        --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
        --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
        --overwrite_output_dir \
        --gradient_accumulation_steps 2 \
        --num_train_epochs 5 \
        --warmup_ratio 0.1 \
        --logging_steps 5 \
        --learning_rate 1e-4 \
        --predict_with_generate \
        --num_beams 4 \
        --save_strategy no \
        --evaluation_strategy no \
        --weight_decay 1e-2 \
        --max_source_length 512 \
        --label_smoothing_factor 0.1 \
        --do_lowercase True \
        --load_best_model_at_end True \
        --greater_is_better True \
        --save_total_limit 10 \
        --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}

SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3

MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}

fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
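
Invoked this way, the wrapper takes three positional arguments (saving path, per-device train batch size, per-device eval batch size), writes its outputs under <saving_path>/lifelong/try, and fine-tunes t5-base on squad with the flags listed in fulldata_hfdata; stdout and stderr are tee'd to ${SAVING_PATH}/log.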

diana/downstream/diana.py (Normal file, +827 lines)
@@ -0,0 +1,827 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metat5 import T5ForConditionalGeneration as PromptT5
from downstream.dataset_processors import *
from downstream.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'],
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks (task-prompt slots) supported by the model."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of QA task formats (format-prompt slots)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The number of soft prompt tokens allocated to each prompt group."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the input."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
# NOTE: the prompt-id injection logic below is disabled; it is kept commented out for reference.
# logger.info(f'Loading task {data_args.dataset_name} prompt')
# format_id = format2id[dataset2format[data_args.dataset_name]]
# task_id = task2id[data_args.dataset_name]
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
# input_ids = copy.deepcopy(
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
# model_inputs['input_ids'] = input_ids
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
# model_inputs['attention_mask']]
# input_ids = copy.deepcopy([input])
return model_inputs
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def save_load_diverse_sample(model,trainset):
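# Build a small, diverse replay buffer for the current task: compute each example's encoder
# query vector, match it against the learned domain keys, keep at most 3 examples per
# top-5 matching key (searching only the first quarter of the training set), then sample
# 50 of the selected indices and return both the subset and its query features.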
with torch.no_grad():
device = 'cuda:0'
search_upbound = len(trainset)//4
query_idxs = [None]*30
keys = model.encoder.domain_keys
for idx,item in enumerate(trainset.select(range(search_upbound))):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
result = torch.matmul(query,keys.t())
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
key_sel = None
for key_idx in result:
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
key_sel = key_idx
break
if key_sel is not None:
if query_idxs[key_sel] is None:
query_idxs[key_sel] = [idx]
else:
query_idxs[key_sel].append(idx)
total_idxs = []
for item in query_idxs:
try:
total_idxs.extend(item[:3])
except TypeError:
# this key was never selected within the search range, so back-fill with random examples
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
total_idxs = list(set(total_idxs))
# keep a fixed replay budget of 50 examples per task
total_idxs = random.sample(total_idxs,50)
sub_set = trainset.select(total_idxs)
features = []
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
# reload the format-specific task prompt for a newly involved task
# (format prompts / task prompts)
data_args.reload_from_trained_prompt = False  # hard-coded override of the CLI flag
data_args.load_from_format_task_id = False  # hard-coded override of the CLI flag
# NOTE: intended to run before (continual) pre-training
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialized format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(),'the format mapping does not match'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restored the task+format prompt for task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
# start of the continual (lifelong) learning setup and training loop
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
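# each eval entry pairs a tokenized eval set with its raw examples (presumably used for answer post-processing)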
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
eval_dataloaders[ds_name] = (load_from_disk("./oursfinallong/{}-eval.hf".format(ds_name)),load_from_disk("./oursfinallong/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
for cur_dataset in task_sequence:
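# Per-task procedure: train on the new task mixed with the accumulated replay buffer,
# collect 50 replay examples plus their query features, run one more training pass on the
# full replay buffer, register the features as negatives, then evaluate on all tasks seen so far.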
trainds = load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset))
p = min(200, len(trainds))  # note: p is computed here but not used later
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["drop","quoref","dream"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric(dataset_name_to_metric[cur_dataset])
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,trainds])
else:
fused = trainds
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
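# rank 0 collects the replay subset and its query features and writes them to disk;
# all ranks then reload them after the barrier below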
if training_args.local_rank<=0:
save_set,features = save_load_diverse_sample(model,trainds)
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=all_replay,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
model.encoder.add_negs(all_ids,all_features)
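# evaluate on every task seen so far (plus the held-out tasks appended after the final training task)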
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric(dataset_name_to_metric[pre_dataset])
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
kwargs = {}  # leftover from the upstream HF script; not used further here
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

53
diana/downstream/diana.sh Normal file
View File

@ -0,0 +1,53 @@
function fulldata_hfdata() {
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 diana.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
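# Example invocation (illustrative output path and batch sizes):
#   bash diana.sh ./outputs 8 32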

View File

@ -0,0 +1,798 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metat5nometa import T5ForConditionalGeneration as PromptT5
from downstream.dataset_processors import *
from downstream.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'],
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
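# fixed integer id for every dataset used in the lifelong-learning benchmark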
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks supported by the prompt pool."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of QA task formats (e.g. extractive, abstractive, multi-choice, boolean)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The number of prompt tokens allocated to each task/format."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the input."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
# NOTE: the prompt-id injection logic below is disabled; it is kept commented out for reference.
# logger.info(f'Loading task {data_args.dataset_name} prompt')
# format_id = format2id[dataset2format[data_args.dataset_name]]
# task_id = task2id[data_args.dataset_name]
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
# input_ids = copy.deepcopy(
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
# model_inputs['input_ids'] = input_ids
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
# model_inputs['attention_mask']]
# input_ids = copy.deepcopy([input])
return model_inputs
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def save_load_diverse_sample(model,trainset):
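# Simpler replay-sampling variant: take the first 50 training examples as the replay buffer
# and collect only their encoder query vectors as features.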
features = []
device='cuda:0'
with torch.no_grad():
sub_set = trainset.select(range(50))
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
# reload the format-specific task prompt for a newly involved task
# (format prompts / task prompts)
data_args.reload_from_trained_prompt = False  # hard-coded override of the CLI flag
data_args.load_from_format_task_id = False  # hard-coded override of the CLI flag
# NOTE: intended to run before (continual) pre-training
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialized format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(),'the format mapping does not match'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restored the task+format prompt for task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
# start of the continual (lifelong) learning setup and training loop
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
eval_dataloaders[ds_name] = (load_from_disk("./oursnometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursnometa/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
for cur_dataset in task_sequence:
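# same lifelong-learning loop as in the variant above, but each task uses only its first
# 1000 training examples and the replay collection on rank 0 is wrapped in a try/except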
trainds = load_from_disk("./oursnometa/{}-train.hf".format(cur_dataset)).select(range(1000))
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["drop","quoref","dream"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric(dataset_name_to_metric[cur_dataset])
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,trainds])
else:
fused = trainds
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
if training_args.local_rank<=0:
try:
save_set,features = save_load_diverse_sample(model,trainds)
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
except:
pass
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=all_replay,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
model.encoder.add_negs(all_ids,all_features)
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric(dataset_name_to_metric[pre_dataset])
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
kwargs = {}  # leftover from the upstream HF script; not used further here
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
function fulldata_hfdata() {
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
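# Usage: bash <this script> SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE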

View File

@ -0,0 +1,827 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metat5notask import T5ForConditionalGeneration as PromptT5
from downstream.dataset_processors import *
from downstream.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'],
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
            'help': 'Whether to lowercase the input text.'
}
)
append_another_bos: bool = field(
default=False,
metadata={
            'help': 'Whether to prepend an additional BOS token to the input.'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
            'help': 'The path to the stored trained prompt embedding.'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
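    # Nested helpers:
    # preprocess_validation_function tokenizes validation inputs/targets and records an
    # "example_id" for every feature so that post-processing can map generated answers back
    # to their original examples; padded label tokens are replaced with -100 so they are
    # ignored by the loss.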
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
        # The prompt-id injection below is disabled in this variant and kept only for reference.
# logger.info(f'Loading task {data_args.dataset_name} prompt')
# format_id = format2id[dataset2format[data_args.dataset_name]]
# task_id = task2id[data_args.dataset_name]
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
# input_ids = copy.deepcopy(
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
# model_inputs['input_ids'] = input_ids
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
# model_inputs['attention_mask']]
# input_ids = copy.deepcopy([input])
return model_inputs
def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids)
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
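    # save_load_diverse_sample builds the replay-buffer entry for the current task: it scans the
    # first quarter of the training set, routes each example to one of its top-5 matching domain
    # keys (at most 3 examples per key), falls back to random examples for keys that matched
    # nothing, keeps a random subset of 50 examples, and returns that subset together with the
    # query vectors that are later registered as negatives via encoder.add_negs.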
def save_load_diverse_sample(model,trainset):
with torch.no_grad():
device = 'cuda:0'
search_upbound = len(trainset)//4
query_idxs = [None]*30
keys = model.encoder.domain_keys
for idx,item in enumerate(trainset.select(range(search_upbound))):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
result = torch.matmul(query,keys.t())
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
key_sel = None
for key_idx in result:
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
key_sel = key_idx
break
if key_sel is not None:
if query_idxs[key_sel] is None:
query_idxs[key_sel] = [idx]
else:
query_idxs[key_sel].append(idx)
total_idxs = []
            for item in query_idxs:
                if item is not None:
                    total_idxs.extend(item[:3])
                else:
                    # This key matched no sample; fall back to random samples from the unsearched part of the training set.
                    total_idxs.extend(random.sample(list(range(search_upbound, len(trainset))), 3))
total_idxs = list(set(total_idxs))
total_idxs = random.sample(total_idxs,50)
sub_set = trainset.select(total_idxs)
features = []
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
    # Reload the format-specific task prompt for a newly involved task
    # (format prompts / task prompts).
    data_args.reload_from_trained_prompt = False
    data_args.load_from_format_task_id = False
    # Both flags are forced off for this run, so the two branches below are effectively disabled.
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
        assert data_args.trained_prompt_path, 'Must specify the path of the stored prompt'
        prompt_info = torch.load(data_args.trained_prompt_path)
        assert prompt_info['task2id'][data_args.dataset_name] == task2id[data_args.dataset_name], \
            f'The task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
        assert prompt_info['format2id'].keys() == format2id.keys(), 'The stored formats do not match the current formats'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cuda")
n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
#start
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
eval_dataloaders[ds_name] = (load_from_disk("./oursnotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursnotask/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
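    # Continual-learning loop over task_sequence: for each task we (1) train on the current
    # task's data fused with the replay buffer, (2) select a small diverse replay subset of the
    # task and append it to the buffer, (3) rehearse on the full replay buffer, and (4) evaluate
    # on every task seen so far (plus drop/quoref/dream after the last task).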
for cur_dataset in task_sequence:
p=200
if p>len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))):
p = len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset)))
trainds = load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["drop","quoref","dream"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric(dataset_name_to_metric[cur_dataset])
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,trainds])
else:
fused = trainds
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
if training_args.local_rank<=0:
save_set,features = save_load_diverse_sample(model,trainds)
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=all_replay,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
model.encoder.add_negs(all_ids,all_features)
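        # Evaluate on all tasks seen so far; max_length/num_beams are reset to None below so
        # generation falls back to the trainer's defaults.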
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric(dataset_name_to_metric[pre_dataset])
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
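    # Leftover from the translation-style argument set: source_lang and target_lang default to
    # None, so `languages` stays empty and this branch never runs.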
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
function fulldata_hfdata() {
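# Args: $1 = saving/output path, $2 = pretrained model path, $3 = per-device train batch size,
# $4 = per-device eval batch size. Launches 4-GPU distributed training with dianawotask.py.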
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
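# Example invocation (SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE); the script name below is
# illustrative, not part of the repository:
#   bash run_diana_lifelong.sh ./outputs 8 16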

556
diana/downstream/l2ptrainer.py Executable file

@ -0,0 +1,556 @@
from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import nltk
import datasets
import re
import os
import numpy as np
import torch
import random
from pathlib import Path
import nltk
# nltk.download('punkt')
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
get_last_checkpoint,
number_of_arguments,
set_seed,
speed_metrics,
)
import warnings
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_truncate,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from transformers.file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
get_full_repo_name,
is_apex_available,
is_datasets_available,
is_in_notebook,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
from transformers.trainer_utils import PredictionOutput,EvalLoopOutput
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
def fix_buggy_characters(text):
    return re.sub("[{}^\\\\`\u2047<]", " ", text)
def replace_punctuation(text):
    return text.replace("\"", "").replace("'", "")
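# score_string_similarity is a crude matcher used to map a generated answer onto one of the
# multiple-choice options: exact match scores 3.0, match after stripping punctuation/odd
# characters scores 2.0, multi-word strings score by token overlap, and anything else scores 0.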
def score_string_similarity(str1, str2):
if str1 == str2:
return 3.0 # Better than perfect token match
str1 = fix_buggy_characters(replace_punctuation(str1))
str2 = fix_buggy_characters(replace_punctuation(str2))
if str1 == str2:
return 2.0
if " " in str1 or " " in str2:
str1_split = str1.split(" ")
str2_split = str2.split(" ")
overlap = list(set(str1_split) & set(str2_split))
return len(overlap) / max(len(str1_split), len(str2_split))
else:
if str1 == str2:
return 1.0
else:
return 0.0
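# QuestionAnsweringTrainer extends Seq2SeqTrainer with dataset-specific post-processing: generated
# answers are decoded and converted into the format each metric expects (span text, yes/no labels,
# or multiple-choice indices) before metrics are computed.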
class QuestionAnsweringTrainer(Seq2SeqTrainer):
def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs):
super().__init__(*args, **kwargs)
self.eval_examples = eval_examples
self.answer_column_name = answer_column_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
if dataset_name in ['squad', "quoref", 'ropes', 'drop', 'nqopen']:
self.post_process_function = self._post_process_squad
elif dataset_name in ['boolq', 'multirc', 'boolq_np']:
self.post_process_function = self._post_process_boolq
elif dataset_name== 'narrativeqa':
self.post_process_function = self._post_process_narrative_qa
elif dataset_name in ['race', 'mctest', 'social_iqa']:
self.post_process_function = self._post_process_race
elif dataset_name in ["squad_v2", "newsqa"]:
self.post_process_function = self._post_process_squad_v2
elif dataset_name in ['openbookqa']:
self.post_process_function = self._post_process_openbookqa
elif dataset_name in ['dream']:
self.post_process_function = self._post_process_dream
else:
print(f"{dataset_name} has not QuestionAnsweringTrainer.post_process_function!")
raise NotImplementedError
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
inputs["train"]=True
else:
labels = None
outputs,sim_loss = model(**inputs)
# print(model.encoder.)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if labels is not None:
loss = self.label_smoother(outputs, labels)
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss += sim_loss*0.1
# print(sim_loss)
return (loss, outputs) if return_outputs else loss
def _post_process_squad(
self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
,version_2_with_negative=False):
# Decode the predicted tokens.
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
predictions = {}
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated to the current example.
            try:
                feature_index = feature_per_example[example_index]
            except KeyError:
                # Debug output when an example has no matching feature.
                print(feature_per_example)
                print(example_index)
                print(len(feature_per_example))
predictions[example["id"]] = decoded_preds[feature_index]
# Format the result to the format the metric expects.
if version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
def _post_process_squad_v2(
self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"):
# Decode the predicted tokens.
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
predictions = {}
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated to the current example.
feature_index = feature_per_example[example_index]
predictions[example["id"]] = decoded_preds[feature_index]
# Format the result to the format the metric expects.
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
@classmethod
def postprocess_text(cls,preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
return preds, labels
def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
outputs = []
# for pred in preds:
# if 'True' == pred:
# outputs.append(1)
# elif 'False' == pred:
# outputs.append(0)
# else:
# outputs.append(0)
# references = [1 if ex['answer'] else 0 for ex in examples]
references = []
for pred, ref in zip(preds, examples):
ans = ref['answer']
references.append(1 if ans else 0)
if pred.lower() in ['true', 'false']:
outputs.append(1 if pred.lower() == 'true' else 0)
else:
# generate a wrong prediction if pred is not true/false
outputs.append(0 if ans else 1)
assert(len(references)==len(outputs))
formatted_predictions = []
return EvalPrediction(predictions=outputs, label_ids=references)
def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
references = [exp['answers'][0]['text'] for exp in examples]
formatted_predictions = []
preds, references = self.postprocess_text(preds,references)
assert(len(preds)==len(references))
return EvalPrediction(predictions=preds, label_ids=references)
def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
references = [exp['answers'][0]['text'] for exp in examples]
formatted_predictions = []
preds, references = self.postprocess_text(preds, references)
assert (len(preds) == len(references))
return EvalPrediction(predictions=preds, label_ids=references)
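    # Multiple-choice post-processing: the generated text is matched against every option with
    # score_string_similarity and the index of the best-scoring option becomes the prediction.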
def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples]
assert(len(references)==len(decoded_preds))
gold_ids , pred_ids = [],[]
for prediction, reference in zip(decoded_preds, references):
# reference = json.loads(reference)
gold = int(ord(reference['answer'].strip()) - ord('A'))
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
# print('options',options,type(options))
# print('prediction',prediction)
# print('answer',gold)
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
# selected_ans = chr(ord('A') + max_idx)
# print(len(references),len(decoded_preds))
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples]
assert len(references) == len(decoded_preds)
gold_ids, pred_ids = [], []
for prediction, reference in zip(decoded_preds, references):
# reference = json.loads(reference)
gold = int(ord(reference['answer'].strip()) - ord('A'))
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples]
assert(len(references)==len(decoded_preds))
gold_ids, pred_ids = [],[]
for prediction, reference in zip(decoded_preds, references):
gold = reference['options'].index(reference['answer'])
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
return EvalPrediction(predictions=pred_ids, label_ids = gold_ids)
def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "eval",
max_length: Optional[int] = None,num_beams: Optional[int] = None):
self._memory_tracker.start()
self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_examples = self.eval_examples if eval_examples is None else eval_examples
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics
if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output,tokenizer)
metrics = self.compute_metrics(eval_preds)
if self.dataset_name=='narrativeqa':
metrics = {key: value * 100 for key, value in metrics.items()}
metrics = {k: round(v, 4) for k, v in metrics.items()}
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
print(metrics)
else:
metrics = {}
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
self._memory_tracker.stop_and_update_metrics(metrics)
return metrics
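    # _save_checkpoint is overridden with save_only_best hard-coded to True: regular per-step
    # checkpoints and checkpoint rotation are skipped, and the best model so far (judged by
    # metric_for_best_model) is written to `<run_dir>/best-checkpoint` instead.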
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
save_only_best = True
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
run_id = tune.get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
run_dir = os.path.join(self.args.output_dir, run_name)
else:
run_dir = self.args.output_dir
self.store_flos()
output_dir = os.path.join(run_dir, checkpoint_folder)
if not save_only_best:
self.save_model(output_dir)
if self.deepspeed:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_fp16_weights_on_model_save` is True
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.dp_rank() == 0:
# Consolidate the state dict on all processed of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.args.should_save and not save_only_best:
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed and not save_only_best:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
self.do_grad_scaling=True
# if self.do_grad_scaling:
# torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
best_dir = os.path.join(run_dir,'best-checkpoint')
# if not os.path.exists(best_dir):
# os.mkdir(best_dir,exist_ok=True)
if self.args.should_save:
self.save_model(best_dir)
self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME))
# Save the Trainer state
if self.args.should_save and not save_only_best:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
if self.args.local_rank == -1:
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()
if is_torch_tpu_available():
rng_states["xla"] = xm.get_rng_state()
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save and not save_only_best:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
# def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict",
# max_length: Optional[int] = None,num_beams: Optional[int] = None):
# self._memory_tracker.start()
# self._max_length = max_length if max_length is not None else self.args.generation_max_length
# self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
# predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset
# predict_dataloader = self.get_eval_dataloader(predict_dataset)
# predict_examples = predict_examples
#
# # Temporarily disable metric computation, we will do it in the loop here.
# compute_metrics = self.compute_metrics
# self.compute_metrics = None
# eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
# try:
# output = eval_loop(
# predict_dataloader,
# description="Predict",
# # No point gathering the predictions if there are no metrics, otherwise we defer to
# # self.args.prediction_loss_only
# prediction_loss_only=True if compute_metrics is None else None,
# ignore_keys=ignore_keys,
# )
# finally:
# self.compute_metrics = compute_metrics
#
# if self.post_process_function is not None and self.compute_metrics is not None:
# predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer)
# metrics = self.compute_metrics(predict_preds)
# if self.dataset_name=='narrativeqa':
# metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
# metrics = {k: round(v, 4) for k, v in metrics.items()}
# # Prefix all keys with metric_key_prefix + '_'
# for key in list(metrics.keys()):
# if not key.startswith(f"{metric_key_prefix}_"):
# metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
#
# self.log(metrics)
# else:
# metrics = {}
#
# self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
# self._memory_tracker.stop_and_update_metrics(metrics)
# return metrics


@ -0,0 +1,105 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric."""
from sklearn.metrics import accuracy_score
import datasets
_DESCRIPTION = """
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative
"""
_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `int`): Predicted labels.
references (`list` of `int`): Ground truth labels.
normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
    sample_weight (`list` of `float`): Sample weights. Defaults to None.
Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input if `normalize` is set to `False`. A higher score means higher accuracy.
Examples:
Example 1-A simple example
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
>>> print(results)
{'accuracy': 0.5}
Example 2-The same as Example 1, except with `normalize` set to `False`.
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
>>> print(results)
{'accuracy': 3.0}
Example 3-The same as Example 1, except with `sample_weight` set.
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
>>> print(results)
{'accuracy': 0.8778625954198473}
"""
_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
)
def _compute(self, predictions, references, normalize=True, sample_weight=None):
return {
"accuracy": float(
accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
)
}
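# A minimal usage sketch, assuming this file is saved as metric/accuracy.py as referenced by
# dataset_name_to_metric in the training script:
#   metric = datasets.load_metric("metric/accuracy.py")
#   metric.compute(predictions=[0, 1, 1], references=[0, 1, 0])  # -> {'accuracy': 0.666...}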


@ -0,0 +1,126 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BLEU metric. """
import datasets
from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
_CITATION = """\
@INPROCEEDINGS{Papineni02bleu:a,
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
booktitle = {},
year = {2002},
pages = {311--318}
}
@inproceedings{lin-och-2004-orange,
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
author = "Lin, Chin-Yew and
Och, Franz Josef",
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
month = "aug 23{--}aug 27",
year = "2004",
address = "Geneva, Switzerland",
publisher = "COLING",
url = "https://www.aclweb.org/anthology/C04-1072",
pages = "501--507",
}
"""
_DESCRIPTION = """\
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
the better it is" this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
remains one of the most popular automated and inexpensive metrics.
Scores are calculated for individual translated segmentsgenerally sentencesby comparing them with a set of good quality reference translations.
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
are not taken into account[citation needed].
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
reference translations will increase the BLEU score.
"""
_KWARGS_DESCRIPTION = """
Computes BLEU score of translated segments against one or more references.
Args:
predictions: list of translations to score.
Each translation should be tokenized into a list of tokens.
references: list of lists of references for each translation.
Each reference should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
'bleu': bleu score,
'precisions': geometric mean of n-gram precisions,
'brevity_penalty': brevity penalty,
'length_ratio': ratio of lengths,
'translation_length': translation_length,
'reference_length': reference_length
Examples:
>>> predictions = [
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
... ]
>>> references = [
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
... ]
>>> bleu = datasets.load_metric("bleu")
>>> results = bleu.compute(predictions=predictions, references=references)
>>> print(results["bleu"])
1.0
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Bleu(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
"references": datasets.Sequence(
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
),
}
),
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
reference_urls=[
"https://en.wikipedia.org/wiki/BLEU",
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
],
)
def _compute(self, predictions, references, max_order=4, smooth=False):
score = compute_bleu(
reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth
)
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
return {
"bleu": bleu,
"precisions": precisions,
"brevity_penalty": bp,
"length_ratio": ratio,
"translation_length": translation_length,
"reference_length": reference_length,
}


@ -0,0 +1,219 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BLEU metric. """
import datasets
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
_CITATION = """\
@INPROCEEDINGS{Papineni02bleu:a,
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
booktitle = {},
year = {2002},
pages = {311--318}
}
@inproceedings{lin-och-2004-orange,
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
author = "Lin, Chin-Yew and
Och, Franz Josef",
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
month = "aug 23{--}aug 27",
year = "2004",
address = "Geneva, Switzerland",
publisher = "COLING",
url = "https://www.aclweb.org/anthology/C04-1072",
pages = "501--507",
}
"""
_DESCRIPTION = """\
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
the better it is" this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
remains one of the most popular automated and inexpensive metrics.
Scores are calculated for individual translated segments (generally sentences) by comparing them with a set of good quality reference translations.
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
are not taken into account.
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
reference translations will increase the BLEU score.
"""
_KWARGS_DESCRIPTION = """
Computes BLEU score of translated segments against one or more references.
Args:
predictions: list of translations to score.
Each translation should be tokenized into a list of tokens.
references: list of lists of references for each translation.
Each reference should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
'bleu': bleu score,
'precisions': geometric mean of n-gram precisions,
'brevity_penalty': brevity penalty,
'length_ratio': ratio of lengths,
'translation_length': translation_length,
'reference_length': reference_length
Examples:
>>> predictions = [
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
... ]
>>> references = [
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
... ]
>>> bleu = datasets.load_metric("bleu")
>>> results = bleu.compute(predictions=predictions, references=references)
>>> print(results["bleu"])
1.0
"""
import collections
import math
def _get_ngrams(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
method.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i+order])
ngram_counts[ngram] += 1
return ngram_counts
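# Example: _get_ngrams(["a", "b", "a"], max_order=2) returns
#   Counter({("a",): 2, ("b",): 1, ("a", "b"): 1, ("b", "a"): 1}).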
def compute_bleu(reference_corpus, translation_corpus, max_order=1,
smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
precisions and brevity penalty.
"""
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for (references, translation) in zip(reference_corpus,
translation_corpus):
# translation = references[0]
# Note: only the translation tokens are lowercased here; reference tokens are used as given.
translation = [item.lower() for item in translation]
reference_length += min(len(r) for r in references)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
for reference in references:
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
# print(merged_ref_ngram_counts)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram)-1] += overlap[ngram]
for order in range(1, max_order+1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order-1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = ((matches_by_order[i] + 1.) /
(possible_matches_by_order[i] + 1.))
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (float(matches_by_order[i]) /
possible_matches_by_order[i])
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
else:
bp = math.exp(1 - 1. / ratio)
bleu = geo_mean * bp
return (bleu, precisions, bp, ratio, translation_length, reference_length)
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Bleu(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
"references": datasets.Sequence(
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
),
}
),
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
reference_urls=[
"https://en.wikipedia.org/wiki/BLEU",
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
],
)
def _compute(self, predictions, references, max_order=4, smooth=False):
# Note: max_order is pinned to 1 below (unigram matching), so the max_order argument above is effectively ignored.
score = compute_bleu(
reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth
)
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
return {
"bleu": bleu,
"precisions": precisions,
"brevity_penalty": bp,
"length_ratio": ratio,
"translation_length": translation_length,
"reference_length": reference_length,
}
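# Small optional sanity check for the local compute_bleu above (which defaults to
# unigram matching). It only runs when this file is executed directly and assumes
# the `datasets` package imported at the top of the file is installed.
if __name__ == "__main__":
    refs = [[["the", "cat", "sat", "on", "the", "mat"]]]
    hyps = [["the", "cat", "sat", "on", "the", "mat"]]
    bleu, precisions, bp, ratio, t_len, r_len = compute_bleu(refs, hyps)
    print(bleu, precisions, bp)  # an exact match scores 1.0 with brevity penalty 1.0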

View File

@ -0,0 +1,133 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ROUGE metric from Google Research github repo. """
# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
import absl # Here to have a nice missing dependency error message early on
import nltk # Here to have a nice missing dependency error message early on
import numpy # Here to have a nice missing dependency error message early on
import six # Here to have a nice missing dependency error message early on
from .rouge_score import RougeScorer
from .scoring import *
from .rouge_tokenizers import *
import datasets
_CITATION = """\
@inproceedings{lin-2004-rouge,
title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
author = "Lin, Chin-Yew",
booktitle = "Text Summarization Branches Out",
month = jul,
year = "2004",
address = "Barcelona, Spain",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W04-1013",
pages = "74--81",
}
"""
_DESCRIPTION = """\
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
evaluating automatic summarization and machine translation software in natural language processing.
The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.
Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.
This metric is a wrapper around the Google Research reimplementation of ROUGE:
https://github.com/google-research/google-research/tree/master/rouge
"""
_KWARGS_DESCRIPTION = """
Calculates average rouge scores for a list of hypotheses and references
Args:
predictions: list of predictions to score. Each prediction
should be a string with tokens separated by spaces.
references: list of references, one per prediction. Each
reference should be a string with tokens separated by spaces.
rouge_types: A list of rouge types to calculate.
Valid names:
`"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
`"rougeL"`: Longest common subsequence based scoring.
`"rougeLsum"`: rougeLsum splits text using `"\n"`.
See details in https://github.com/huggingface/datasets/issues/617
use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
use_aggregator: Return aggregates if this is set to True
Returns:
rouge1: rouge_1 (precision, recall, f1),
rouge2: rouge_2 (precision, recall, f1),
rougeL: rouge_l (precision, recall, f1),
rougeLsum: rouge_lsum (precision, recall, f1)
Examples:
>>> rouge = datasets.load_metric('rouge')
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> results = rouge.compute(predictions=predictions, references=references)
>>> print(list(results.keys()))
['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
>>> print(results["rouge1"])
AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
>>> print(results["rouge1"].mid.fmeasure)
1.0
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Rouge(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}
),
codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
reference_urls=[
"https://en.wikipedia.org/wiki/ROUGE_(metric)",
"https://github.com/google-research/google-research/tree/master/rouge",
],
)
def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
if rouge_types is None:
rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
if use_aggregator:
aggregator = BootstrapAggregator()
else:
scores = []
for ref, pred in zip(references, predictions):
score = scorer.score(ref, pred)
if use_aggregator:
aggregator.add_scores(score)
else:
scores.append(score)
if use_aggregator:
result = aggregator.aggregate()
for key in rouge_types:
result[key] = result[key].mid.fmeasure
else:
result = {}
for key in scores[0]:
result[key] = list(score[key] for score in scores)
return result

View File

@ -0,0 +1,318 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes rouge scores between two text blobs.
Implementation replicates the functionality in the original ROUGE package. See:
Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
Barcelona, Spain, July 25 - 26, 2004.
Default options are equivalent to running:
ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
Or with use_stemmer=True:
ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
In these examples settings.xml lists input files and formats.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from absl import logging
import nltk
import numpy as np
import six
from six.moves import map
from six.moves import range
from .scoring import *
from .rouge_tokenizers import *
class RougeScorer(BaseScorer):
"""Calculate rouge scores between two blobs of text.
Sample usage:
scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
scores = scorer.score('The quick brown fox jumps over the lazy dog',
'The quick brown dog jumps on the log.')
"""
def __init__(self, rouge_types, use_stemmer=False, split_summaries=False,
tokenizer=None):
"""Initializes a new RougeScorer.
Valid rouge types that can be computed are:
rougen (e.g. rouge1, rouge2): n-gram based scoring.
rougeL: Longest common subsequence based scoring.
Args:
rouge_types: A list of rouge types to calculate.
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching. This arg is used in the
DefaultTokenizer, but other tokenizers might or might not choose to
use this.
split_summaries: whether to add newlines between sentences for rougeLsum
tokenizer: Tokenizer object which has a tokenize() method.
Returns:
A dict mapping rouge types to Score tuples.
"""
self.rouge_types = rouge_types
if tokenizer:
self._tokenizer = tokenizer
else:
self._tokenizer = DefaultTokenizer(use_stemmer)
logging.info("Using default tokenizer.")
self._split_summaries = split_summaries
def score_multi(self, targets, prediction):
"""Calculates rouge scores between targets and prediction.
The target with the maximum f-measure is used for the final score for
each score type.
Args:
targets: list of texts containing the targets
prediction: Text containing the predicted text.
Returns:
A dict mapping each rouge type to a Score object.
Raises:
ValueError: If an invalid rouge type is encountered.
"""
score_dicts = [self.score(t, prediction) for t in targets]
max_score = {}
for k in self.rouge_types:
index = np.argmax([s[k].fmeasure for s in score_dicts])
max_score[k] = score_dicts[index][k]
return max_score
def score(self, target, prediction):
"""Calculates rouge scores between the target and prediction.
Args:
target: Text containing the target (ground truth) text.
prediction: Text containing the predicted text.
Returns:
A dict mapping each rouge type to a Score object.
Raises:
ValueError: If an invalid rouge type is encountered.
"""
# Pre-compute target tokens and prediction tokens for use by different
# types, except if only "rougeLsum" is requested.
if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum":
target_tokens = None
prediction_tokens = None
else:
target_tokens = self._tokenizer.tokenize(target)
prediction_tokens = self._tokenizer.tokenize(prediction)
result = {}
for rouge_type in self.rouge_types:
if rouge_type == "rougeL":
# Rouge from longest common subsequences.
scores = _score_lcs(target_tokens, prediction_tokens)
elif rouge_type == "rougeLsum":
# Note: Does not support multi-line text.
def get_sents(text):
if self._split_summaries:
sents = nltk.sent_tokenize(text)
else:
# Assume sentences are separated by newline.
sents = six.ensure_str(text).split("\n")
sents = [x for x in sents if len(x)]
return sents
target_tokens_list = [
self._tokenizer.tokenize(s) for s in get_sents(target)]
prediction_tokens_list = [
self._tokenizer.tokenize(s) for s in get_sents(prediction)]
scores = _summary_level_lcs(target_tokens_list,
prediction_tokens_list)
elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)):
# Rouge from n-grams.
n = int(rouge_type[5:])
if n <= 0:
raise ValueError("rougen requires positive n: %s" % rouge_type)
target_ngrams = _create_ngrams(target_tokens, n)
prediction_ngrams = _create_ngrams(prediction_tokens, n)
scores = _score_ngrams(target_ngrams, prediction_ngrams)
else:
raise ValueError("Invalid rouge type: %s" % rouge_type)
result[rouge_type] = scores
return result
def _create_ngrams(tokens, n):
"""Creates ngrams from the given list of tokens.
Args:
tokens: A list of tokens from which ngrams are created.
n: Number of tokens to use, e.g. 2 for bigrams.
Returns:
A dictionary mapping each ngram to the number of occurrences.
"""
ngrams = collections.Counter()
for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)):
ngrams[ngram] += 1
return ngrams
def _score_lcs(target_tokens, prediction_tokens):
"""Computes LCS (Longest Common Subsequence) rouge scores.
Args:
target_tokens: Tokens from the target text.
prediction_tokens: Tokens from the predicted text.
Returns:
A Score object containing computed scores.
"""
if not target_tokens or not prediction_tokens:
return Score(precision=0, recall=0, fmeasure=0)
# Compute length of LCS from the bottom up in a table (DP approach).
lcs_table = _lcs_table(target_tokens, prediction_tokens)
lcs_length = lcs_table[-1][-1]
precision = lcs_length / len(prediction_tokens)
recall = lcs_length / len(target_tokens)
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
def _lcs_table(ref, can):
"""Create 2-d LCS score table."""
rows = len(ref)
cols = len(can)
lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
for i in range(1, rows + 1):
for j in range(1, cols + 1):
if ref[i - 1] == can[j - 1]:
lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
else:
lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
return lcs_table
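# Example: _lcs_table(["the", "cat", "sat"], ["the", "sat"])[-1][-1] == 2,
# since "the ... sat" is the longest common subsequence of the two token lists.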
def _backtrack_norec(t, ref, can):
"""Read out LCS."""
i = len(ref)
j = len(can)
lcs = []
while i > 0 and j > 0:
if ref[i - 1] == can[j - 1]:
lcs.insert(0, i-1)
i -= 1
j -= 1
elif t[i][j - 1] > t[i - 1][j]:
j -= 1
else:
i -= 1
return lcs
def _summary_level_lcs(ref_sent, can_sent):
"""ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.
Args:
ref_sent: list of tokenized reference sentences
can_sent: list of tokenized candidate sentences
Returns:
summary level ROUGE score
"""
if not ref_sent or not can_sent:
return Score(precision=0, recall=0, fmeasure=0)
m = sum(map(len, ref_sent))
n = sum(map(len, can_sent))
if not n or not m:
return Score(precision=0, recall=0, fmeasure=0)
# get token counts to prevent double counting
token_cnts_r = collections.Counter()
token_cnts_c = collections.Counter()
for s in ref_sent:
# s is a list of tokens
token_cnts_r.update(s)
for s in can_sent:
token_cnts_c.update(s)
hits = 0
for r in ref_sent:
lcs = _union_lcs(r, can_sent)
# Prevent double-counting:
# The paper describes just computing hits += len(_union_lcs()),
# but the implementation prevents double counting. We also
# implement this as in version 1.5.5.
for t in lcs:
if token_cnts_c[t] > 0 and token_cnts_r[t] > 0:
hits += 1
token_cnts_c[t] -= 1
token_cnts_r[t] -= 1
recall = hits / m
precision = hits / n
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
def _union_lcs(ref, c_list):
"""Find union LCS between a ref sentence and list of candidate sentences.
Args:
ref: list of tokens
c_list: list of candidate sentences, each a list of tokens
Returns:
List of tokens in ref representing union LCS.
"""
lcs_list = [lcs_ind(ref, c) for c in c_list]
return [ref[i] for i in _find_union(lcs_list)]
def _find_union(lcs_list):
"""Finds union LCS given a list of LCS."""
return sorted(list(set().union(*lcs_list)))
def lcs_ind(ref, can):
"""Returns one of the longest lcs."""
t = _lcs_table(ref, can)
return _backtrack_norec(t, ref, can)
def _score_ngrams(target_ngrams, prediction_ngrams):
"""Compute n-gram based rouge scores.
Args:
target_ngrams: A Counter object mapping each ngram to number of
occurrences for the target text.
prediction_ngrams: A Counter object mapping each ngram to number of
occurrences for the prediction text.
Returns:
A Score object containing computed scores.
"""
intersection_ngrams_count = 0
for ngram in six.iterkeys(target_ngrams):
intersection_ngrams_count += min(target_ngrams[ngram],
prediction_ngrams[ngram])
target_ngrams_count = sum(target_ngrams.values())
prediction_ngrams_count = sum(prediction_ngrams.values())
precision = intersection_ngrams_count / max(prediction_ngrams_count, 1)
recall = intersection_ngrams_count / max(target_ngrams_count, 1)
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
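# Example (assuming the enclosing package is importable so the relative imports above resolve):
#   scorer = RougeScorer(["rouge1", "rougeL"])
#   scorer.score("the cat sat on the mat", "the cat sat")
#   -> {"rouge1": Score(precision=1.0, recall=0.5, fmeasure~=0.67),
#       "rougeL": Score(precision=1.0, recall=0.5, fmeasure~=0.67)}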

View File

@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Library containing Tokenizer definitions.
The RougeScorer class can be instantiated with the tokenizers defined here. New
tokenizers can be defined by creating a subclass of the Tokenizer abstract class
and overriding the tokenize() method.
"""
import abc
from nltk.stem import porter
import re
import six
# Pre-compile regexes that are used often
NON_ALPHANUM_PATTERN = r"[^a-z0-9]+"
NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN)
SPACES_PATTERN = r"\s+"
SPACES_RE = re.compile(SPACES_PATTERN)
VALID_TOKEN_PATTERN = r"^[a-z0-9]+$"
VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN)
def _tokenize(text, stemmer):
"""Tokenize input text into a list of tokens.
This approach aims to replicate the approach taken by Chin-Yew Lin in
the original ROUGE implementation.
Args:
text: A text blob to tokenize.
stemmer: An optional stemmer.
Returns:
A list of string tokens extracted from input text.
"""
# Convert everything to lowercase.
text = text.lower()
# Replace any non-alpha-numeric characters with spaces.
text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text))
tokens = SPACES_RE.split(text)
if stemmer:
# Only stem words more than 3 characters long.
tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x
for x in tokens]
# One final check to drop any empty or invalid tokens.
tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)]
return tokens
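# Example: _tokenize("The Quick, brown fox!", stemmer=None)
#   -> ["the", "quick", "brown", "fox"]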
class Tokenizer(abc.ABC):
"""Abstract base class for a tokenizer.
Subclasses of Tokenizer must implement the tokenize() method.
"""
@abc.abstractmethod
def tokenize(self, text):
raise NotImplementedError("Tokenizer must override tokenize() method")
class DefaultTokenizer(Tokenizer):
"""Default tokenizer which tokenizes on whitespace."""
def __init__(self, use_stemmer=False):
"""Constructor for DefaultTokenizer.
Args:
use_stemmer: boolean, indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
"""
self._stemmer = porter.PorterStemmer() if use_stemmer else None
def tokenize(self, text):
return _tokenize(text, self._stemmer)

View File

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for scoring and evaluation of text samples.
Aggregation functions use bootstrap resampling to compute confidence intervals
as per the original ROUGE perl implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
from typing import Dict
import numpy as np
import six
from six.moves import range
class Score(
collections.namedtuple("Score", ["precision", "recall", "fmeasure"])):
"""Tuple containing precision, recall, and f-measure values."""
class BaseScorer(object, metaclass=abc.ABCMeta):
"""Base class for Scorer objects."""
@abc.abstractmethod
def score(self, target, prediction):
"""Calculates score between the target and prediction.
Args:
target: Text containing the target (ground truth) text.
prediction: Text containing the predicted text.
Returns:
A dict mapping each score_type (string) to Score object.
"""
class AggregateScore(
collections.namedtuple("AggregateScore", ["low", "mid", "high"])):
"""Tuple containing confidence intervals for scores."""
class BootstrapAggregator(object):
"""Aggregates scores to provide confidence intervals.
Sample usage:
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
aggregator = Aggregator()
aggregator.add_scores(scorer.score("one two three", "one two"))
aggregator.add_scores(scorer.score("one two five six", "seven eight"))
result = aggregator.aggregate()
print result
{'rougeL': AggregateScore(
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
'rouge1': AggregateScore(
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
"""
def __init__(self, confidence_interval=0.95, n_samples=1000):
"""Initializes a BootstrapAggregator object.
Args:
confidence_interval: Confidence interval to compute on the mean as a
decimal.
n_samples: Number of samples to use for bootstrap resampling.
Raises:
ValueError: If invalid argument is given.
"""
if confidence_interval < 0 or confidence_interval > 1:
raise ValueError("confidence_interval must be in range [0, 1]")
if n_samples <= 0:
raise ValueError("n_samples must be positive")
self._n_samples = n_samples
self._confidence_interval = confidence_interval
self._scores = collections.defaultdict(list)
def add_scores(self, scores):
"""Adds a sample for future aggregation.
Args:
scores: Dict mapping score_type strings to a namedtuple object/class
representing a score.
"""
for score_type, score in six.iteritems(scores):
self._scores[score_type].append(score)
def aggregate(self):
"""Aggregates scores previously added using add_scores.
Returns:
A dict mapping score_type to AggregateScore objects.
"""
result = {}
for score_type, scores in six.iteritems(self._scores):
# Stack scores into a 2-d matrix of (sample, measure).
score_matrix = np.vstack(tuple(scores))
# Percentiles are returned as (interval, measure).
percentiles = self._bootstrap_resample(score_matrix)
# Extract the three intervals (low, mid, high).
intervals = tuple(
(scores[0].__class__(*percentiles[j, :]) for j in range(3)))
result[score_type] = AggregateScore(
low=intervals[0], mid=intervals[1], high=intervals[2])
return result
def _bootstrap_resample(self, matrix):
"""Performs bootstrap resampling on a matrix of scores.
Args:
matrix: A 2-d matrix of (sample, measure).
Returns:
A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
mid (row 1) and high (row 2). Mid is always the mean, while low and high
bounds are specified by self._confidence_interval (which defaults to 0.95
meaning it will return the 2.5th and 97.5th percentiles for a 95%
confidence interval on the mean).
"""
# Matrix of (bootstrap sample, measure).
sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
for i in range(self._n_samples):
sample_idx = np.random.choice(
np.arange(matrix.shape[0]), size=matrix.shape[0])
sample = matrix[sample_idx, :]
sample_mean[i, :] = np.mean(sample, axis=0)
# Take percentiles on the estimate of the mean using bootstrap samples.
# Final result is a (bounds, measure) matrix.
percentile_delta = (1 - self._confidence_interval) / 2
q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
return np.percentile(sample_mean, q, axis=0)
def func_fmeasure(precision, recall):
"""Computes f-measure given precision and recall values."""
if precision + recall > 0:
return 2 * precision * recall / (precision + recall)
else:
return 0.0
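# Optional usage sketch, run only when this module is executed directly; it
# assumes numpy and six (imported above) are installed.
if __name__ == "__main__":
    agg = BootstrapAggregator(n_samples=100)
    agg.add_scores({"rouge1": Score(precision=1.0, recall=0.5, fmeasure=func_fmeasure(1.0, 0.5))})
    agg.add_scores({"rouge1": Score(precision=0.5, recall=0.5, fmeasure=func_fmeasure(0.5, 0.5))})
    # mid is the bootstrap mean; low/high bound the chosen confidence interval.
    print(agg.aggregate()["rouge1"].mid)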

View File

@ -0,0 +1,92 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
import argparse
import json
import re
import string
import sys
from collections import Counter
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
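# Example: f1_score("a yellow dog", "the dog")
#   normalize_answer drops articles and punctuation, leaving ["yellow", "dog"] vs ["dog"],
#   so precision = 1/2, recall = 1/1 and F1 = 2/3.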
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x["text"], qa["answers"]))
prediction = predictions[qa["id"]]
exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total
return {"exact_match": exact_match, "f1": f1}
if __name__ == "__main__":
expected_version = "1.1"
parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
parser.add_argument("dataset_file", help="Dataset file")
parser.add_argument("prediction_file", help="Prediction File")
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
if dataset_json["version"] != expected_version:
print(
"Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
file=sys.stderr,
)
dataset = dataset_json["data"]
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(evaluate(dataset, predictions)))

View File

@ -0,0 +1,109 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD metric. """
import datasets
from .evaluate import evaluate
_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
booktitle={EMNLP},
year={2016}
}
"""
_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
"""
_KWARGS_DESCRIPTION = """
Computes SQuAD scores (F1 and EM).
Args:
predictions: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair as given in the references (see below)
- 'prediction_text': the text of the answer
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a Dict in the SQuAD dataset format
{
'text': list of possible texts for the answer, as a list of strings
'answer_start': list of start positions for the answer, as a list of ints
}
Note that answer_start values are not taken into account to compute the metric.
Returns:
'exact_match': Exact match (the normalized answer exactly matches the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
Examples:
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
>>> squad_metric = datasets.load_metric("squad")
>>> results = squad_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'exact_match': 100.0, 'f1': 100.0}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{
"text": datasets.Value("string"),
"answer_start": datasets.Value("int32"),
}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
def _compute(self, predictions, references):
pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
dataset = [
{
"paragraphs": [
{
"qas": [
{
"answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
"id": ref["id"],
}
for ref in references
]
}
]
}
]
score = evaluate(dataset=dataset, predictions=pred_dict)
return score

View File

@ -0,0 +1,330 @@
"""Official evaluation script for SQuAD version 2.0.
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import os
import re
import string
import sys
import numpy as np
OPTS = None
def parse_args():
parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.")
parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
parser.add_argument(
"--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)."
)
parser.add_argument(
"--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer."
)
parser.add_argument(
"--na-prob-thresh",
"-t",
type=float,
default=1.0,
help='Predict "" if no-answer probability exceeds this (default = 1.0).',
)
parser.add_argument(
"--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory."
)
parser.add_argument("--verbose", "-v", action="store_true")
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
return qid_to_has_ans
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
def compute_exact(a_gold, a_pred):
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
def compute_f1(a_gold, a_pred):
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def get_raw_scores(dataset, preds):
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid = qa["id"]
if qid not in preds:
print(f"Missing prediction for {qid}")
continue
a_pred = preds[qid]
gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = [""]
if a_pred != "":
exact_scores[qid] = 0
f1_scores[qid] = 0
else:
exact_scores[qid] = 1
f1_scores[qid] = 1
else:
# Take max over all gold answers
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores.values()) / total),
("f1", 100.0 * sum(f1_scores.values()) / total),
("total", total),
]
)
else:
total = len(qid_list)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval[f"{prefix}_{k}"] = new_eval[k]
def plot_pr_curve(precisions, recalls, out_image, title):
plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.xlim([0.0, 1.05])
plt.ylim([0.0, 1.05])
plt.title(title)
plt.savefig(out_image)
plt.clf()
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i + 1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
if out_image:
plot_pr_curve(precisions, recalls, out_image, title)
return {"ap": 100.0 * avg_prec}
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
if out_image_dir and not os.path.exists(out_image_dir):
os.makedirs(out_image_dir)
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = make_precision_recall_eval(
exact_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_exact.png"),
title="Precision-Recall curve for Exact Match score",
)
pr_f1 = make_precision_recall_eval(
f1_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_f1.png"),
title="Precision-Recall curve for F1 score",
)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = make_precision_recall_eval(
oracle_scores,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_oracle.png"),
title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
)
merge_eval(main_eval, pr_exact, "pr_exact")
merge_eval(main_eval, pr_f1, "pr_f1")
merge_eval(main_eval, pr_oracle, "pr_oracle")
def histogram_na_prob(na_probs, qid_list, image_dir, name):
if not qid_list:
return
x = [na_probs[k] for k in qid_list]
weights = np.ones_like(x) / float(len(x))
plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
plt.xlabel("Model probability of no-answer")
plt.ylabel("Proportion of dataset")
plt.title(f"Histogram of no-answer probability: {name}")
plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
plt.clf()
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for i, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
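# The sweep in find_best_thresh walks questions in order of increasing no-answer
# probability, starting from the score obtained by predicting "" for every question
# (num_no_ans). Answering a has-answer question adds its raw score; answering a
# no-answer question with non-empty text costs one point. The threshold at the
# running maximum is returned alongside the corresponding score.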
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
def main():
with open(OPTS.data_file) as f:
dataset_json = json.load(f)
dataset = dataset_json["data"]
with open(OPTS.pred_file) as f:
preds = json.load(f)
if OPTS.na_prob_file:
with open(OPTS.na_prob_file) as f:
na_probs = json.load(f)
else:
na_probs = {k: 0.0 for k in preds}
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, preds)
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
out_eval = make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
merge_eval(out_eval, has_ans_eval, "HasAns")
if no_ans_qids:
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
merge_eval(out_eval, no_ans_eval, "NoAns")
if OPTS.na_prob_file:
find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
if OPTS.na_prob_file and OPTS.out_image_dir:
run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
if OPTS.out_file:
with open(OPTS.out_file, "w") as f:
json.dump(out_eval, f)
else:
print(json.dumps(out_eval, indent=2))
if __name__ == "__main__":
OPTS = parse_args()
if OPTS.out_image_dir:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
main()

View File

@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD v2 metric. """
import datasets
from .evaluate import (
apply_no_ans_threshold,
find_all_best_thresh,
get_raw_scores,
make_eval_dict,
make_qid_to_has_ans,
merge_eval,
)
_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
booktitle={EMNLP},
year={2016}
}
"""
_DESCRIPTION = """
This metric wraps the official scoring script for version 2 of the Stanford Question
Answering Dataset (SQuAD).
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions
written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but also
determine when no answer is supported by the paragraph and abstain from answering.
"""
_KWARGS_DESCRIPTION = """
Computes SQuAD v2 scores (F1 and EM).
Args:
predictions: List of triples for question-answers to score, with the following elements:
- the question-answer 'id' field as given in the references (see below)
- the text of the answer
- the probability that the question has no answer
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a list of Dict {'text': text of the answer as a string}
no_answer_threshold: float
Probability threshold to decide that a question has no answer.
Returns:
'exact': Exact match (the normalized answer exactly matches the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
'total': Number of scores considered
'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer)
'HasAns_f1': The F-score of predicted tokens versus the gold answer
'HasAns_total': Number of scores considered
'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer)
'NoAns_f1': The F-score of predicted tokens versus the gold answer
'NoAns_total': Number of scores considered
'best_exact': Best exact match (with varying threshold)
'best_exact_thresh': No-answer probability threshold associated to the best exact match
'best_f1': Best F1 (with varying threshold)
'best_f1_thresh': No-answer probability threshold associated to the best F1
Examples:
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}]
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
>>> squad_v2_metric = datasets.load_metric("squad_v2")
>>> results = squad_v2_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SquadV2Local(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {
"id": datasets.Value("string"),
"prediction_text": datasets.Value("string"),
"no_answer_probability": datasets.Value("float32"),
},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{"text": datasets.Value("string"), "answer_start": datasets.Value("int32")}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
def _compute(self, predictions, references, no_answer_threshold=1.0):
# no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions)
dataset = [{"paragraphs": [{"qas": references}]}]
predictions = dict((p["id"], p["prediction_text"]) for p in predictions)
# qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
# has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
# no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, predictions)
out_eval = make_eval_dict(exact_raw, f1_raw)
#
#
# exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
# f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
# out_eval = make_eval_dict(exact_thresh, f1_thresh)
#
# if has_ans_qids:
# has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
# merge_eval(out_eval, has_ans_eval, "HasAns")
# if no_ans_qids:
# no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
# merge_eval(out_eval, no_ans_eval, "NoAns")
# find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
return dict(out_eval)

View File

@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import SimpleQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput
def preprocess_sqaud_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_sqaud_abstractive_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_boolq_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
question_column, context_column, answer_column = 'question', 'passage', 'answer'
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
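# For BoolQ this produces inputs of the form "<question> \n <passage>" and the
# targets are the strings "True" / "False" (str() of the boolean answer field).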
def preprocess_boolq_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_narrativeqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [exp['summary']['text'] for exp in examples['document']]
questions = [exp['text'] for exp in examples['question']]
answers = [ans[0]['text'] for ans in examples['answers']]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_narrativeqa_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_drop_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['passage']
questions = examples['question']
answers = examples['answers']
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_race_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 }
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
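# For RACE this yields inputs of the form
#   "<question> \n options: A. ...; B. ...; C. ...; D. ... \n <article>"
# (with the literal "\n" separator produced by SimpleQAInput), and the target is
# the full text of the gold option rather than its letter label.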
def preprocess_newsqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
# inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_ropes_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
backgrounds = examples["background"]
situations = examples["situation"]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_openbookqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples['question_stem']
all_options = examples['choices']
answers = examples['answerKey']
options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_social_iqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2,}
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_dream_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
questions = examples['question']
all_options = examples['choice']
answers = examples['answer']
answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
targets = answers
return inputs, targets


@ -0,0 +1,21 @@
#ProQA
#[paper]https://arxiv.org/abs/2205.04040
class StructuralQAInput:
@classmethod
def qg_input(cls, context, question, options=None):
question = question
source_text = f'[QUESTION] {question}. [CONTEXT] {context} the answer is: '
return source_text
class SimpleQAInput:
@classmethod
def qg_input(cls, context, question, options=None):
question = question
source_text = f'{question} \\n {context}'
return source_text
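# Illustrative note (added, not part of the original file): for question="Who wrote Hamlet?" and
# context="Hamlet is a tragedy by William Shakespeare.", the two wrappers produce roughly
#   StructuralQAInput.qg_input -> '[QUESTION] Who wrote Hamlet?. [CONTEXT] Hamlet is a tragedy by William Shakespeare. the answer is: '
#   SimpleQAInput.qg_input     -> 'Who wrote Hamlet? \n Hamlet is a tragedy by William Shakespeare.'
# (the "\n" above is a literal backslash-n marker in the source string, not a newline).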


@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import StructuralQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput
def preprocess_sqaud_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_sqaud_abstractive_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_boolq_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
question_column, context_column, answer_column = 'question', 'passage', 'answer'
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_boolq_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_narrativeqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [exp['summary']['text'] for exp in examples['document']]
questions = [exp['text'] for exp in examples['question']]
answers = [ans[0]['text'] for ans in examples['answers']]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_narrativeqa_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_drop_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['passage']
questions = examples['question']
answers = examples['answers']
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_race_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 }
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_newsqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
# inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_ropes_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
backgrounds = examples["background"]
situations = examples["situation"]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_openbookqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples['question_stem']
all_options = examples['choices']
answers = examples['answerKey']
options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_social_iqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2,}
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_dream_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
questions = examples['question']
all_options = examples['choice']
answers = examples['answer']
answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
targets = answers
return inputs, targets
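# Usage sketch (added, illustrative only): each preprocess_*_batch function maps a batched
# `datasets` dict of columns to parallel lists of source and target strings, e.g.
#   inputs, targets = preprocess_sqaud_batch(batch, "question", "context", "answers")
# where `batch` comes from a Dataset processed with batched=True; the resulting strings are
# then tokenized by the training script's preprocess_function.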


@ -0,0 +1,801 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metadeca import T5ForConditionalGeneration as PromptT5
from metrics import compute_metrics
from downstreamdeca.simple_processors import *
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
task2id = {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
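# Note (added for clarity): task2id assigns each task in the decaNLP-style sequence a slot used
# to index its task-specific prompt block, and task2format maps it to one of the three answer
# formats declared in format2id above (0 = extractive, 1 = abstractive, 2 = multichoice).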
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
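# Summary comment (added): save_load_diverse_sample picks a small replay buffer from the current
# task. It scores the first quarter of the training set against the encoder's domain keys, keeps
# at most 3 examples per matched key (falling back to random samples when a key has none),
# downsamples to 50 examples, and returns them together with their query vectors so they can
# later be registered as replay features / negatives.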
def save_load_diverse_sample(model,trainset):
with torch.no_grad():
device = 'cuda:0'
search_upbound = len(trainset)//4
query_idxs = [None]*30
keys = model.encoder.domain_keys
for idx,item in enumerate(trainset.select(range(search_upbound))):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
result = torch.matmul(query,keys.t())
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
key_sel = None
for key_idx in result:
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
key_sel = key_idx
break
if key_sel is not None:
if query_idxs[key_sel] is None:
query_idxs[key_sel] = [idx]
else:
query_idxs[key_sel].append(idx)
total_idxs = []
for item in query_idxs:
try:
total_idxs.extend(item[:3])
except:
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
total_idxs = list(set(total_idxs))
total_idxs = random.sample(total_idxs,50)
sub_set = trainset.select(total_idxs)
features = []
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
#reload format specific task-prompt for newly involved task
#format_prompts###task_prompts
data_args.reload_from_trained_prompt = False#@
data_args.load_from_format_task_id = False#@
### before pretrain come !!!!!!
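# Layout note (added): encoder.prompt_embeddings is assumed to be organised as
#   [format prompts: prompt_number rows per format] followed by [task prompts: prompt_number rows per task],
# so task_start_id = prompt_number * num_formats and each task occupies the slice
# [task_start_id + task2id[name] * prompt_number, task_start_id + (task2id[name] + 1) * prompt_number).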
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(),'the format mappings do not match'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cpu" if training_args.no_cuda else "cuda")
n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
#start
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
if True:
cur_dataset = ds_name
train_dataloaders[ds_name] = load_from_disk("./oursdeca/{}-train.hf".format(cur_dataset)).select(range(200))
eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdeca/{}-evalex.hf".format(cur_dataset)).select(range(200)))
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(ds_name)),load_from_disk("./oursdeca/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
need_to_do_dss = []
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
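# Flow note (added): the loop below performs sequential (lifelong) training over task_sequence.
# For each task it (1) concatenates the stored replay set with the new task's training data,
# (2) trains for 5 epochs, (3) samples ~50 examples plus their query features into the replay
# buffer via save_load_diverse_sample, and (4) re-evaluates every task seen so far (plus the
# three held-out tasks cnn_dailymail, multinli.in.out and zre after the last stage).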
for cur_dataset in task_sequence:
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
else:
fused = train_dataloaders[cur_dataset]
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
if training_args.local_rank<=0:
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.add_negs(all_ids,all_features)
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
kwargs = {}  # defined here so that setting --source_lang/--target_lang does not raise a NameError below
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()


@ -0,0 +1,53 @@
function fulldata_hfdata() {
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 diana.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
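# Example invocation (added; output path and batch sizes below are placeholders):
#   bash <this_script>.sh ./outputs 8 16
# This trains t5-base with the arguments above and tees the log to ./outputs/lifelong/try/log.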


@ -0,0 +1,6 @@
after_train_ squad _test_ squad
OrderedDict([('eval_em', 88.0), ('eval_nf1', 91.79166666666667), ('eval_nem', 89.5), ('eval_samples', 200)])
after_train_ wikisql _test_ squad
OrderedDict([('eval_em', 83.0), ('eval_nf1', 87.73095238095239), ('eval_nem', 84.5), ('eval_samples', 200)])
after_train_ wikisql _test_ wikisql
OrderedDict([('eval_lfem', 30.5), ('eval_em', 17.5), ('eval_nf1', 84.06011500628662), ('eval_nem', 20.0), ('eval_samples', 200)])


@ -0,0 +1,775 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metadecanometa import T5ForConditionalGeneration as PromptT5
from metrics import compute_metrics
from downstreamdeca.simple_processors import *
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
task2id = {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
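# Note (added): unlike the key-matching variant earlier in this commit, this
# save_load_diverse_sample simply takes the first 50 training examples as the replay set and
# returns them together with their encoder query vectors; no diversity-based selection is done.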
def save_load_diverse_sample(model,trainset):
with torch.no_grad():
device = 'cuda:0'
sub_set = trainset.select(range(50))
features = []
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
# Reload the format-specific task prompt for a newly added task.
# Prompt-embedding layout: format prompts come first, followed by the per-task prompts.
data_args.reload_from_trained_prompt = False  # hard-coded off for this run
data_args.load_from_format_task_id = False  # hard-coded off for this run
# NOTE: these overrides must stay before the prompt-reloading logic below.
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialized the {dataset2format[data_args.dataset_name]} format prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name], f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(), 'the stored formats do not match the current formats'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restored the task+format prompt for task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cpu") if training_args.no_cuda else torch.device("cuda")
n_gpu = 0 if training_args.no_cuda else torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
#start
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
if True:
cur_dataset = ds_name
train_dataloaders[ds_name] = load_from_disk("./oursdecanometa/{}-train.hf".format(cur_dataset)).select(range(200))
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(cur_dataset)).select(range(200)))
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
need_to_do_dss = []
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
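# Continual-learning loop: train on each task in sequence (mixing in the replay buffer),
# then keep 50 diverse replay samples for the task, register their query features as negatives
# in the encoder, and evaluate on every task seen so far.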
for cur_dataset in task_sequence:
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
else:
fused = train_dataloaders[cur_dataset]
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
if training_args.local_rank<=0:
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.add_negs(all_ids,all_features)
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
kwargs = {}  # model-card metadata; defined here so the assignment below cannot raise a NameError
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
function fulldata_hfdata() {
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
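# Usage (positional args): bash <this_script>.sh SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE
# e.g. (hypothetical values): bash <this_script>.sh ./outputs 8 32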
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}

View File

@ -0,0 +1,801 @@
#!/usr/bin/env python
# coding=utf-8
import logging
import os
import torch
import copy,random
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple
from sklearn.cluster import KMeans
# from data import QAData
sys.path.append("../")
from models.metadecanotask import T5ForConditionalGeneration as PromptT5
from metrics import compute_metrics
from downstreamdeca.simple_processors import *
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
import datasets
import numpy as np
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
import os
from functools import partial
os.environ["WANDB_DISABLED"] = "true"
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from pathlib import Path
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
format2dataset = {
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
}
seed_datasets = ['race','narrativeqa','squad']
dataset2format= {}
task2id = {}
for k,vs in format2dataset.items():
for v in vs:
dataset2format[v] = k
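# Hard-coded task ids / task-format ids for the decaNLP-style lifelong task stream
# (the format ids appear to follow format2id above: 0=extractive, 1=abstractive, 2=multichoice).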
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
},
)
fix_t5: Optional[bool] = field(
default=False,
metadata={"help": "whether to fix the main parameters of t5"}
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
validation_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
"a jsonlines file."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
after_train: Optional[str] = field(default=None)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
"during ``evaluate`` and ``predict``."
},
)
max_seq_length: int = field(
default=384,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
"which is used during ``evaluate`` and ``predict``."
},
)
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
},
)
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
forced_bos_token: Optional[str] = field(
default=None,
metadata={
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
"needs to be the target language token.(Usually it is the target language token)"
},
)
do_lowercase: bool = field(
default=False,
metadata={
'help':'Whether to process input into lowercase'
}
)
append_another_bos: bool = field(
default=False,
metadata={
'help':'Whether to add another bos'
}
)
context_column: Optional[str] = field(
default="context",
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
)
question_column: Optional[str] = field(
default="question",
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
)
answer_column: Optional[str] = field(
default="answers",
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
)
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks supported by the prompt pool."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of QA task formats (e.g. extractive, abstractive, multi-choice, bool)."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The number of prompt tokens allocated to each task/format."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the input."},
)
reload_from_trained_prompt: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from trained prompt"},
)
trained_prompt_path:Optional[str] = field(
default = None,
metadata={
'help':'the path storing trained prompt embedding'
}
)
load_from_format_task_id: Optional[bool] = field(
default=False,
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
# elif self.source_lang is None or self.target_lang is None:
# raise ValueError("Need to specify the source language and the target language.")
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
# assert extension == "json", "`train_file` should be a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
# assert extension == "json", "`validation_file` should be a json file."
if self.val_max_target_length is None:
self.val_max_target_length = self.max_target_length
# +
def main():
def preprocess_validation_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def save_prompt_embedding(model,path):
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
prompt_path = os.path.join(path,'prompt_embedding_info')
torch.save(save_prompt_info,prompt_path)
logger.info(f'Saving prompt embedding information to {prompt_path}')
def preprocess_function(examples):
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
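# save_load_diverse_sample picks a small, diverse replay set: each example's query vector is routed
# to its top-5 prompt keys, at most 3 examples are kept per key (falling back to random examples when
# a key stays empty), and 50 of them are sampled and returned together with their query features.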
def save_load_diverse_sample(model,trainset):
with torch.no_grad():
device = 'cuda:0'
search_upbound = len(trainset)//4
query_idxs = [None]*30
keys = model.encoder.domain_keys
for idx,item in enumerate(trainset.select(range(search_upbound))):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
result = torch.matmul(query,keys.t())
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
key_sel = None
for key_idx in result:
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
key_sel = key_idx
break
if key_sel is not None:
if query_idxs[key_sel] is None:
query_idxs[key_sel] = [idx]
else:
query_idxs[key_sel].append(idx)
total_idxs = []
for item in query_idxs:
try:
total_idxs.extend(item[:3])
except:
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
total_idxs = list(set(total_idxs))
total_idxs = random.sample(total_idxs,50)
sub_set = trainset.select(total_idxs)
features = []
for idx,item in enumerate(sub_set):
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
return_dict=True)
features.append(query.detach().cpu().numpy())
return sub_set,features
def dataset_name_to_func(dataset_name):
mapping = {
'squad': preprocess_sqaud_batch,
'squad_v2': preprocess_sqaud_batch,
'boolq': preprocess_boolq_batch,
'narrativeqa': preprocess_narrativeqa_batch,
'race': preprocess_race_batch,
'newsqa': preprocess_newsqa_batch,
'quoref': preprocess_sqaud_batch,
'ropes': preprocess_ropes_batch,
'drop': preprocess_drop_batch,
'nqopen': preprocess_sqaud_abstractive_batch,
# 'multirc': preprocess_boolq_batch,
'boolq_np': preprocess_boolq_batch,
'openbookqa': preprocess_openbookqa_batch,
'mctest': preprocess_race_batch,
'social_iqa': preprocess_social_iqa_batch,
'dream': preprocess_dream_batch,
}
return mapping[dataset_name]
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dataset_name_to_metric = {
'squad': 'metric/squad_v1_local/squad_v1_local.py',
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
'boolq': 'metric/accuracy.py',
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
'race': 'metric/accuracy.py',
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
'drop': 'metric/squad_v1_local/squad_v1_local.py',
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
'boolq_np': 'metric/accuracy.py',
'openbookqa': 'metric/accuracy.py',
'mctest': 'metric/accuracy.py',
'social_iqa': 'metric/accuracy.py',
'dream': 'metric/accuracy.py',
}
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
]:
logger.warning(
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
"`--source_prefix 'translate English to German: ' `"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
# '[OPTIONS]'])
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
'[OPTIONS]']}
tokenizer.add_tokens(tokens_to_add)
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
added_tokens = tokenizer.get_added_vocab()
logger.info('Added tokens: {}'.format(added_tokens))
model = PromptT5.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
# task_num = data_args.max_task_num,
# prompt_num = data_args.prompt_number,
# format_num = data_args.qa_task_type_num,
# add_task_prompt = False
)
model.resize_token_embeddings(len(tokenizer))
# Reload the format-specific task prompt for a newly added task.
# Prompt-embedding layout: format prompts come first, followed by the per-task prompts.
data_args.reload_from_trained_prompt = False  # hard-coded off for this run
data_args.load_from_format_task_id = False  # hard-coded off for this run
# NOTE: these overrides must stay before the prompt-reloading logic below.
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
logger.info(f'Successfully initialized the {dataset2format[data_args.dataset_name]} format prompt for new task {data_args.dataset_name}, task id {task_id}')
# print(dataset2format[data_args.dataset_name])
# print(data_args.dataset_name)
elif data_args.reload_from_trained_prompt:
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name], f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys()==format2id.keys(), 'the stored formats do not match the current formats'
task_start_id = data_args.prompt_number * len(format2dataset.keys())
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
format_id = format2id[dataset2format[data_args.dataset_name]]
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
logger.info(
f'Successfully restored the task+format prompt for task {data_args.dataset_name} from {data_args.trained_prompt_path}')
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
if training_args.local_rank == -1 or training_args.no_cuda:
device = torch.device("cpu") if training_args.no_cuda else torch.device("cuda")
n_gpu = 0 if training_args.no_cuda else torch.cuda.device_count()
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
logger.warning(
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
)
question_column = data_args.question_column
context_column = data_args.context_column
answer_column = data_args.answer_column
# import random
if data_args.max_source_length > tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
#start
train_dataloaders = {}
eval_dataloaders = {}
replay_dataloaders = {}
all_replay = None
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
if True:
cur_dataset = ds_name
train_dataloaders[ds_name] = load_from_disk("./oursdecanotask/{}-train.hf".format(cur_dataset)).select(range(200))
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(cur_dataset)).select(range(200)))
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(ds_name)))
pre_tasks = []
pre_general = []
pre_test = []
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
need_to_do_dss = []
fileout = open("diana_log.txt",'w')
all_replay = None
all_features = []
all_ids = []
cluster_num=0
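# Continual-learning loop: train on each task in sequence (mixing in the replay buffer),
# then keep 50 diverse replay samples for the task, register their query features as negatives
# in the encoder, and evaluate on every task seen so far.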
for cur_dataset in task_sequence:
cluster_num+=5
pre_tasks.append(cur_dataset)
if cur_dataset==task_sequence[-1]:
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
data_args.dataset_name = cur_dataset
logger.info("current_dataset:"+cur_dataset)
training_args.do_train = True
training_args.to_eval = False
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
if all_replay is not None:
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
else:
fused = train_dataloaders[cur_dataset]
training_args.num_train_epochs = 5
model.encoder.reset_train_count()
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=fused,
eval_dataset=None,
eval_examples=None,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
train_result = trainer.train()
if training_args.local_rank<=0:
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
if all_replay is None:
all_replay = save_set
else:
all_replay = datasets.concatenate_datasets([all_replay,save_set])
if all_features==[]:
all_features=features
else:
all_features.extend(features)
np.save("./all_features.npy",np.array(all_features))
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
torch.distributed.barrier()
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
all_ids.extend([task2id[cur_dataset]]*50)
all_features=np.load("./all_features.npy").tolist()
model.encoder.add_negs(all_ids,all_features)
for pre_dataset in pre_tasks:
data_args.dataset_name = pre_dataset
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
trainer = QuestionAnsweringTrainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=eval_dataset,
eval_examples=eval_examples,
answer_column_name=answer_column,
dataset_name=data_args.dataset_name,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
torch.cuda.empty_cache()
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
max_length, num_beams, ignore_keys_for_eval = None, None, None
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.local_rank<=0:
try:
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
print(metrics,file=fileout)
except:
pass
kwargs = {}  # model-card metadata; defined here so the assignment below cannot raise a NameError
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
if len(languages) > 0:
kwargs["language"] = languages
return None
# -
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
function fulldata_hfdata() {
SAVING_PATH=$1
PRETRAIN_MODEL_PATH=$2
TRAIN_BATCH_SIZE=$3
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
--output_dir ${SAVING_PATH} \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
--overwrite_output_dir \
--gradient_accumulation_steps 2 \
--num_train_epochs 5 \
--warmup_ratio 0.1 \
--logging_steps 5 \
--learning_rate 1e-4 \
--predict_with_generate \
--num_beams 4 \
--save_strategy no \
--evaluation_strategy no \
--weight_decay 1e-2 \
--max_source_length 512 \
--label_smoothing_factor 0.1 \
--do_lowercase True \
--load_best_model_at_end True \
--greater_is_better True \
--save_total_limit 10 \
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
}
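# Usage (positional args): bash <this_script>.sh SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE
# e.g. (hypothetical values): bash <this_script>.sh ./outputs 8 32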
SAVING_PATH=$1
TRAIN_BATCH_SIZE=$2
EVAL_BATCH_SIZE=$3
MODE=try
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}

View File

@ -0,0 +1,169 @@
import itertools
import json
import os
import csv
import errno
import random
from random import shuffle
from typing import List
import codecs
import nltk
import glob
import xml.etree.ElementTree as ET
from datasets import load_dataset
#nltk.download("stopwords")
from nltk.corpus import stopwords
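# NOTE: the two tokenization helpers below assume module-level `tokenizer`, `padding`,
# and `preprocess_all` objects exist; they are not called by main() in this file.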
def preprocess_function(examples):
preprocess_fn = preprocess_all#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "input","output")
model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def preprocess_function_test(examples):
preprocess_fn = preprocess_all#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "input","output")
model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)
model_inputs["example_id"] = []
model_inputs["id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(i)
model_inputs["id"].append(i)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def add_id(example,index):
example.update({'id':index})
return example
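# prep / prep_test convert the decaNLP-style "*_to_squad" JSON dumps into jsonlines files with
# question / context / answer fields ({name}-train.jsonl and {name}-eval.jsonl respectively).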
def prep(raw_ds,fname):
ds = []
dss = map(lambda x: x["paragraphs"], raw_ds["data"])
for dd in dss:
ds.extend(dd)
print(len(ds))
print(len(raw_ds["data"]))
examples = {"context":[],"question":[],"answer":[]}
fout = open("{}-train.jsonl".format(fname),'w')
for d in ds:
context = d["context"]
#TOKENIZER.encode(d["context"])
for qa in d["qas"]:
question = qa["question"]#TOKENIZER.encode(qa["question"])
raw_answers = qa["answers"]
if len(raw_answers) == 0:
assert qa["is_impossible"]
raw_answers.append({"text": ""})
answers = []
for i, raw_answer in enumerate(raw_answers):
answers.append(raw_answer["text"])
jsonline = json.dumps({"question":question,"context":context,"answer":answers})
print(jsonline,file=fout)
fout.close()
def prep_test(raw_ds,use_answers,fname):
ds = []
dss = map(lambda x: x["paragraphs"], raw_ds["data"])
for dd in dss:
ds.extend(dd)
print(len(ds))
print(len(raw_ds["data"]))
fout = open("{}-eval.jsonl".format(fname),'w')
idx = 0
f_answers = []
use_answers = None
all_json_lines = []
for d in ds:
context = d["context"]
#TOKENIZER.encode(d["context"])
for qa in d["qas"]:
question = qa["question"]#TOKENIZER.encode(qa["question"])
raw_answers = qa["answers"]
f_answers.extend([_["text"] for _ in qa["answers"]])
if True:
if len(raw_answers) == 0:
assert qa["is_impossible"]
raw_answers.append({"text": ""})
answers = []
for i, raw_answer in enumerate(raw_answers):
answers.append(raw_answer["text"])
all_json_lines.append({"question":question,"context":context,"answer":answers,"preid":qa["id"]})
if fname in ["wikisql","woz.en","multinli.in.out"]:
all_json_lines.sort(key=lambda x: x["preid"])
for item in all_json_lines:
jsonline = json.dumps(item)
print(jsonline,file=fout)
fout.close()
from collections import Counter
import json
def main():
for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]:
print("current_dataset:",item)
print("Loading......")
testfile = open(item+"_to_squad-test-v2.0.json",'r')
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
trainfile = open(item+"_to_squad-train-v2.0.json",'r')
ftrain = json.load(trainfile)
ftest = json.load(testfile)
faddition = None
if item in ["woz.en","wikisql"]:
addition_answers = open(item+"_answers.json",'r')
faddition = json.load(addition_answers)
if item=="woz.en":
faddition = [[_[1]] for _ in faddition]
else:
faddition = [[_["answer"]] for _ in faddition]
else:
faddition = None
print("Loaded.")
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
prep(ftrain,item)
prep_test(ftest,faddition,item)
main()

View File

@ -0,0 +1,557 @@
from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import nltk
import datasets
import re
import os
import numpy as np
import torch
import random
from pathlib import Path
import nltk
# nltk.download('punkt')
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
default_hp_space,
denumpify_detensorize,
get_last_checkpoint,
number_of_arguments,
set_seed,
speed_metrics,
)
import warnings
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_truncate,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from transformers.file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
get_full_repo_name,
is_apex_available,
is_datasets_available,
is_in_notebook,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
from transformers.trainer_utils import PredictionOutput,EvalLoopOutput
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
# `smp` is referenced in _save_checkpoint below; import it only when SageMaker model parallelism is enabled.
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
def fix_buggy_characters(str):
return re.sub("[{}^\\\\`\u2047<]", " ", str)
def replace_punctuation(str):
return str.replace("\"", "").replace("'", "")
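# score_string_similarity rates how well a candidate option matches the generated text:
# 3.0 for an exact raw match, 2.0 for a match after cleaning punctuation/odd characters,
# otherwise the word-overlap ratio between the two strings (used to pick a multiple-choice option).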
def score_string_similarity(str1, str2):
if str1 == str2:
return 3.0 # Better than perfect token match
str1 = fix_buggy_characters(replace_punctuation(str1))
str2 = fix_buggy_characters(replace_punctuation(str2))
if str1 == str2:
return 2.0
if " " in str1 or " " in str2:
str1_split = str1.split(" ")
str2_split = str2.split(" ")
overlap = list(set(str1_split) & set(str2_split))
return len(overlap) / max(len(str1_split), len(str2_split))
else:
if str1 == str2:
return 1.0
else:
return 0.0
class QuestionAnsweringTrainer(Seq2SeqTrainer):
def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs):
super().__init__(*args, **kwargs)
self.eval_examples = eval_examples
self.answer_column_name = answer_column_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
if "cnn" in dataset_name:
self.post_process_function = self._post_process_narrative_qa
else:
self.post_process_function = self._post_process_squad
def _post_process_squad(
self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
,version_2_with_negative=False):
# Decode the predicted tokens.
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
predictions = {}
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated to the current example.
try:
feature_index = feature_per_example[example_index]
except:
print(feature_per_example.keys())
print(example_index)
assert False
predictions[example["id"]] = decoded_preds[feature_index]
# Format the result to the format the metric expects.
if version_2_with_negative:
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
def _post_process_squad_v2(
self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"):
# Decode the predicted tokens.
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
predictions = {}
# Let's loop over all the examples!
for example_index, example in enumerate(examples):
# This is the index of the feature associated to the current example.
feature_index = feature_per_example[example_index]
predictions[example["id"]] = decoded_preds[feature_index]
# Format the result to the format the metric expects.
formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
]
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
@classmethod
def postprocess_text(cls,preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]
# rougeLSum expects newline after each sentence
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
return preds, labels
def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
outputs = []
# for pred in preds:
# if 'True' == pred:
# outputs.append(1)
# elif 'False' == pred:
# outputs.append(0)
# else:
# outputs.append(0)
# references = [1 if ex['answer'] else 0 for ex in examples]
references = []
for pred, ref in zip(preds, examples):
ans = ref['answer']
references.append(1 if ans else 0)
if pred.lower() in ['true', 'false']:
outputs.append(1 if pred.lower() == 'true' else 0)
else:
# generate a wrong prediction if pred is not true/false
outputs.append(0 if ans else 1)
assert(len(references)==len(outputs))
formatted_predictions = []
return EvalPrediction(predictions=outputs, label_ids=references)
def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
references = [exp['answers']['text'][0] for exp in examples]
formatted_predictions = []
preds, references = self.postprocess_text(preds,references)
assert(len(preds)==len(references))
return EvalPrediction(predictions=preds, label_ids=references)
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
inputs["train"]=True
else:
labels = None
outputs,sim_loss = model(**inputs)
# print(model.encoder.)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if labels is not None:
loss = self.label_smoother(outputs, labels)
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss += sim_loss*0.1
return (loss, outputs) if return_outputs else loss
def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
# preds = [" ".join(pred) for pred in decoded_preds]
preds = decoded_preds
references = [exp['answers'][0]['text'] for exp in examples]
formatted_predictions = []
preds, references = self.postprocess_text(preds, references)
assert (len(preds) == len(references))
return EvalPrediction(predictions=preds, label_ids=references)
def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples]
assert(len(references)==len(decoded_preds))
gold_ids , pred_ids = [],[]
for prediction, reference in zip(decoded_preds, references):
# reference = json.loads(reference)
gold = int(ord(reference['answer'].strip()) - ord('A'))
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
# print('options',options,type(options))
# print('prediction',prediction)
# print('answer',gold)
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
# selected_ans = chr(ord('A') + max_idx)
# print(len(references),len(decoded_preds))
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples]
assert len(references) == len(decoded_preds)
gold_ids, pred_ids = [], []
for prediction, reference in zip(decoded_preds, references):
# reference = json.loads(reference)
gold = int(ord(reference['answer'].strip()) - ord('A'))
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
preds = outputs.predictions
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples]
assert(len(references)==len(decoded_preds))
gold_ids, pred_ids = [],[]
for prediction, reference in zip(decoded_preds, references):
gold = reference['options'].index(reference['answer'])
options = reference['options']
prediction = prediction.replace("\n", "").strip()
options = [opt.strip() for opt in options if len(opt) > 0]
scores = [score_string_similarity(opt, prediction) for opt in options]
max_idx = np.argmax(scores)
gold_ids.append(gold)
pred_ids.append(max_idx)
return EvalPrediction(predictions=pred_ids, label_ids = gold_ids)
def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "eval",
max_length: Optional[int] = None,num_beams: Optional[int] = None):
gen_kwargs = {}
gen_kwargs["max_length"] = (
max_length
)
gen_kwargs["num_beams"] = (
num_beams
)
self._gen_kwargs = gen_kwargs
self._memory_tracker.start()
self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
eval_dataloader = self.get_eval_dataloader(eval_dataset)
eval_examples = self.eval_examples if eval_examples is None else eval_examples
# Temporarily disable metric computation, we will do it in the loop here.
compute_metrics = self.compute_metrics
self.compute_metrics = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
try:
output = eval_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if compute_metrics is None else None,
ignore_keys=ignore_keys,
)
finally:
self.compute_metrics = compute_metrics
if self.post_process_function is not None and self.compute_metrics is not None:
eval_preds = self.post_process_function(eval_examples, eval_dataset, output,tokenizer)
metrics = self.compute_metrics(eval_preds,
tri=self.dataset_name=="multinli.in.out",
dialogue=self.dataset_name=='woz.en',
rouge=self.dataset_name=='cnn_dailymail',
logical_form=self.dataset_name=='wikisql',
corpus_f1=self.dataset_name=='zre')
if self.dataset_name=='narrativeqa':
metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
metrics = {k: round(v, 4) for k, v in metrics.items()}
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
print(metrics)
else:
metrics = {}
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
self._memory_tracker.stop_and_update_metrics(metrics)
return metrics
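# Usage sketch (illustrative; `trainer`, `eval_dataset`, `eval_examples` and `tokenizer`
# are assumed to be built elsewhere in the repo):
#
#   metrics = trainer.evaluate(
#       eval_dataset=eval_dataset,      # tokenized eval features
#       eval_examples=eval_examples,    # raw examples consumed by the post-processor
#       tokenizer=tokenizer,
#       max_length=64,
#       num_beams=4,
#   )
#   # metric keys are prefixed with "eval_", e.g. {"eval_accuracy": ...}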
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
save_only_best = True
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
run_id = tune.get_trial_id()
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
run_id = trial.id
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
run_dir = os.path.join(self.args.output_dir, run_name)
else:
run_dir = self.args.output_dir
self.store_flos()
output_dir = os.path.join(run_dir, checkpoint_folder)
if not save_only_best:
self.save_model(output_dir)
if self.deepspeed:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_fp16_weights_on_model_save` is True
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.dp_rank() == 0:
# Consolidate the state dict on all processes of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.args.should_save and not save_only_best:
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed and not save_only_best:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
self.do_grad_scaling=True
# if self.do_grad_scaling:
# torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics[metric_to_check]
operator = np.greater if self.args.greater_is_better else np.less
if (
self.state.best_metric is None
or self.state.best_model_checkpoint is None
or operator(metric_value, self.state.best_metric)
):
self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir
best_dir = os.path.join(run_dir,'best-checkpoint')
# if not os.path.exists(best_dir):
# os.mkdir(best_dir,exist_ok=True)
if self.args.should_save:
self.save_model(best_dir)
self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME))
# Save the Trainer state
if self.args.should_save and not save_only_best:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
if self.args.local_rank == -1:
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
else:
rng_states["cuda"] = torch.cuda.random.get_rng_state()
if is_torch_tpu_available():
rng_states["xla"] = xm.get_rng_state()
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
if local_rank == -1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)
# Maybe delete some older checkpoints.
if self.args.should_save and not save_only_best:
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
# def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict",
# max_length: Optional[int] = None,num_beams: Optional[int] = None):
# self._memory_tracker.start()
# self._max_length = max_length if max_length is not None else self.args.generation_max_length
# self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
# predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset
# predict_dataloader = self.get_eval_dataloader(predict_dataset)
# predict_examples = predict_examples
#
# # Temporarily disable metric computation, we will do it in the loop here.
# compute_metrics = self.compute_metrics
# self.compute_metrics = None
# eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
# try:
# output = eval_loop(
# predict_dataloader,
# description="Predict",
# # No point gathering the predictions if there are no metrics, otherwise we defer to
# # self.args.prediction_loss_only
# prediction_loss_only=True if compute_metrics is None else None,
# ignore_keys=ignore_keys,
# )
# finally:
# self.compute_metrics = compute_metrics
#
# if self.post_process_function is not None and self.compute_metrics is not None:
# predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer)
# metrics = self.compute_metrics(predict_preds)
# if self.dataset_name=='narrativeqa':
# metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
# metrics = {k: round(v, 4) for k, v in metrics.items()}
# # Prefix all keys with metric_key_prefix + '_'
# for key in list(metrics.keys()):
# if not key.startswith(f"{metric_key_prefix}_"):
# metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
#
# self.log(metrics)
# else:
# metrics = {}
#
# self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
# self._memory_tracker.stop_and_update_metrics(metrics)
# return metrics

View File

@ -0,0 +1,105 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric."""
from sklearn.metrics import accuracy_score
import datasets
_DESCRIPTION = """
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
Where:
TP: True positive
TN: True negative
FP: False positive
FN: False negative
"""
_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `int`): Predicted labels.
references (`list` of `int`): Ground truth labels.
normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
sample_weight (`list` of `float`): Sample weights. Defaults to None.
Returns:
accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input if `normalize` is set to `False`. A higher score means higher accuracy.
Examples:
Example 1-A simple example
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
>>> print(results)
{'accuracy': 0.5}
Example 2-The same as Example 1, except with `normalize` set to `False`.
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
>>> print(results)
{'accuracy': 3.0}
Example 3-The same as Example 1, except with `sample_weight` set.
>>> accuracy_metric = datasets.load_metric("accuracy")
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
>>> print(results)
{'accuracy': 0.8778625954198473}
"""
_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("int32")),
"references": datasets.Sequence(datasets.Value("int32")),
}
if self.config_name == "multilabel"
else {
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
)
def _compute(self, predictions, references, normalize=True, sample_weight=None):
return {
"accuracy": float(
accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
)
}

View File

@ -0,0 +1,126 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BLEU metric. """
import datasets
from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
_CITATION = """\
@INPROCEEDINGS{Papineni02bleu:a,
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
booktitle = {},
year = {2002},
pages = {311--318}
}
@inproceedings{lin-och-2004-orange,
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
author = "Lin, Chin-Yew and
Och, Franz Josef",
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
month = "aug 23{--}aug 27",
year = "2004",
address = "Geneva, Switzerland",
publisher = "COLING",
url = "https://www.aclweb.org/anthology/C04-1072",
pages = "501--507",
}
"""
_DESCRIPTION = """\
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
the better it is". This is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
remains one of the most popular automated and inexpensive metrics.
Scores are calculated for individual translated segments (generally sentences) by comparing them with a set of good quality reference translations.
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
are not taken into account.
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
reference translations will increase the BLEU score.
"""
_KWARGS_DESCRIPTION = """
Computes BLEU score of translated segments against one or more references.
Args:
predictions: list of translations to score.
Each translation should be tokenized into a list of tokens.
references: list of lists of references for each translation.
Each reference should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
'bleu': bleu score,
'precisions': list of per-order n-gram precisions,
'brevity_penalty': brevity penalty,
'length_ratio': ratio of lengths,
'translation_length': translation_length,
'reference_length': reference_length
Examples:
>>> predictions = [
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
... ]
>>> references = [
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
... ]
>>> bleu = datasets.load_metric("bleu")
>>> results = bleu.compute(predictions=predictions, references=references)
>>> print(results["bleu"])
1.0
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Bleu(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
"references": datasets.Sequence(
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
),
}
),
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
reference_urls=[
"https://en.wikipedia.org/wiki/BLEU",
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
],
)
def _compute(self, predictions, references, max_order=4, smooth=False):
score = compute_bleu(
reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth
)
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
return {
"bleu": bleu,
"precisions": precisions,
"brevity_penalty": bp,
"length_ratio": ratio,
"translation_length": translation_length,
"reference_length": reference_length,
}

View File

@ -0,0 +1,219 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BLEU metric. """
import datasets
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
_CITATION = """\
@INPROCEEDINGS{Papineni02bleu:a,
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
booktitle = {},
year = {2002},
pages = {311--318}
}
@inproceedings{lin-och-2004-orange,
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
author = "Lin, Chin-Yew and
Och, Franz Josef",
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
month = "aug 23{--}aug 27",
year = "2004",
address = "Geneva, Switzerland",
publisher = "COLING",
url = "https://www.aclweb.org/anthology/C04-1072",
pages = "501--507",
}
"""
_DESCRIPTION = """\
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
the better it is". This is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
remains one of the most popular automated and inexpensive metrics.
Scores are calculated for individual translated segments (generally sentences) by comparing them with a set of good quality reference translations.
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
are not taken into account.
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
reference translations will increase the BLEU score.
"""
_KWARGS_DESCRIPTION = """
Computes BLEU score of translated segments against one or more references.
Args:
predictions: list of translations to score.
Each translation should be tokenized into a list of tokens.
references: list of lists of references for each translation.
Each reference should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
'bleu': bleu score,
'precisions': list of per-order n-gram precisions,
'brevity_penalty': brevity penalty,
'length_ratio': ratio of lengths,
'translation_length': translation_length,
'reference_length': reference_length
Examples:
>>> predictions = [
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
... ]
>>> references = [
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
... ]
>>> bleu = datasets.load_metric("bleu")
>>> results = bleu.compute(predictions=predictions, references=references)
>>> print(results["bleu"])
1.0
"""
import collections
import math
def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i+order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus, translation_corpus, max_order=1,
smooth=False):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of lists of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
smooth: Whether or not to apply Lin et al. 2004 smoothing.
Returns:
6-tuple with the BLEU score, the per-order n-gram precisions, the brevity penalty,
the length ratio, the translation length and the reference length.
"""
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
reference_length = 0
translation_length = 0
for (references, translation) in zip(reference_corpus,
translation_corpus):
# translation = references[0]
translation = [item.lower() for item in translation]
reference_length += min(len(r) for r in references)
translation_length += len(translation)
merged_ref_ngram_counts = collections.Counter()
for reference in references:
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
# print(merged_ref_ngram_counts)
translation_ngram_counts = _get_ngrams(translation, max_order)
overlap = translation_ngram_counts & merged_ref_ngram_counts
for ngram in overlap:
matches_by_order[len(ngram)-1] += overlap[ngram]
for order in range(1, max_order+1):
possible_matches = len(translation) - order + 1
if possible_matches > 0:
possible_matches_by_order[order-1] += possible_matches
precisions = [0] * max_order
for i in range(0, max_order):
if smooth:
precisions[i] = ((matches_by_order[i] + 1.) /
(possible_matches_by_order[i] + 1.))
else:
if possible_matches_by_order[i] > 0:
precisions[i] = (float(matches_by_order[i]) /
possible_matches_by_order[i])
else:
precisions[i] = 0.0
if min(precisions) > 0:
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
geo_mean = math.exp(p_log_sum)
else:
geo_mean = 0
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
else:
bp = math.exp(1 - 1. / ratio)
bleu = geo_mean * bp
return (bleu, precisions, bp, ratio, translation_length, reference_length)
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Bleu(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
"references": datasets.Sequence(
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
),
}
),
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
reference_urls=[
"https://en.wikipedia.org/wiki/BLEU",
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
],
)
def _compute(self, predictions, references, max_order=4, smooth=False):
score = compute_bleu(
reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth
)
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
return {
"bleu": bleu,
"precisions": precisions,
"brevity_penalty": bp,
"length_ratio": ratio,
"translation_length": translation_length,
"reference_length": reference_length,
}
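# Minimal sanity-check sketch (not part of the original metric): with the unigram-only
# compute_bleu above, a two-token hypothesis against a three-token reference gets
# unigram precision 2/2 = 1.0 and brevity penalty exp(1 - 3/2) ~= 0.61, so BLEU ~= 0.61.
def _compute_bleu_example():
    refs = [[["the", "cat", "sat"]]]   # one sample with one tokenized reference
    hyps = [["the", "cat"]]            # one tokenized hypothesis
    bleu, precisions, bp, ratio, translation_length, reference_length = compute_bleu(refs, hyps)
    assert precisions == [1.0] and abs(bleu - bp) < 1e-9
    return bleu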

View File

@ -0,0 +1,133 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ROUGE metric from Google Research github repo. """
# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
import absl # Here to have a nice missing dependency error message early on
import nltk # Here to have a nice missing dependency error message early on
import numpy # Here to have a nice missing dependency error message early on
import six # Here to have a nice missing dependency error message early on
from .rouge_score import RougeScorer
from .scoring import *
from .rouge_tokenizers import *
import datasets
_CITATION = """\
@inproceedings{lin-2004-rouge,
title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
author = "Lin, Chin-Yew",
booktitle = "Text Summarization Branches Out",
month = jul,
year = "2004",
address = "Barcelona, Spain",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W04-1013",
pages = "74--81",
}
"""
_DESCRIPTION = """\
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
evaluating automatic summarization and machine translation software in natural language processing.
The metrics compare an automatically produced summary or translation against one or more human-produced reference summaries or translations.
Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.
This metric is a wrapper around the Google Research reimplementation of ROUGE:
https://github.com/google-research/google-research/tree/master/rouge
"""
_KWARGS_DESCRIPTION = """
Calculates average rouge scores for a list of hypotheses and references
Args:
predictions: list of predictions to score. Each prediction
should be a string with tokens separated by spaces.
references: list of reference for each prediction. Each
reference should be a string with tokens separated by spaces.
rouge_types: A list of rouge types to calculate.
Valid names:
`"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
`"rougeL"`: Longest common subsequence based scoring.
`"rougeLSum"`: rougeLsum splits text using `"\n"`.
See details in https://github.com/huggingface/datasets/issues/617
use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
use_aggregator: Return aggregates if this is set to True
Returns:
rouge1: rouge_1 (precision, recall, f1),
rouge2: rouge_2 (precision, recall, f1),
rougeL: rouge_l (precision, recall, f1),
rougeLsum: rouge_lsum (precision, recall, f1)
Examples:
>>> rouge = datasets.load_metric('rouge')
>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "general kenobi"]
>>> results = rouge.compute(predictions=predictions, references=references)
>>> print(list(results.keys()))
['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
>>> print(results["rouge1"])
AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
>>> print(results["rouge1"].mid.fmeasure)
1.0
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Rouge(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}
),
codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
reference_urls=[
"https://en.wikipedia.org/wiki/ROUGE_(metric)",
"https://github.com/google-research/google-research/tree/master/rouge",
],
)
def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
if rouge_types is None:
rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
if use_aggregator:
aggregator = BootstrapAggregator()
else:
scores = []
for ref, pred in zip(references, predictions):
score = scorer.score(ref, pred)
if use_aggregator:
aggregator.add_scores(score)
else:
scores.append(score)
if use_aggregator:
result = aggregator.aggregate()
for key in rouge_types:
result[key] = result[key].mid.fmeasure
else:
result = {}
for key in scores[0]:
result[key] = list(score[key] for score in scores)
return result

View File

@ -0,0 +1,318 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes rouge scores between two text blobs.
Implementation replicates the functionality in the original ROUGE package. See:
Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
Barcelona, Spain, July 25 - 26, 2004.
Default options are equivalent to running:
ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
Or with use_stemmer=True:
ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
In these examples settings.xml lists input files and formats.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from absl import logging
import nltk
import numpy as np
import six
from six.moves import map
from six.moves import range
from .scoring import *
from .rouge_tokenizers import *
class RougeScorer(BaseScorer):
"""Calculate rouges scores between two blobs of text.
Sample usage:
scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
scores = scorer.score('The quick brown fox jumps over the lazy dog',
'The quick brown dog jumps on the log.')
"""
def __init__(self, rouge_types, use_stemmer=False, split_summaries=False,
tokenizer=None):
"""Initializes a new RougeScorer.
Valid rouge types that can be computed are:
rougen (e.g. rouge1, rouge2): n-gram based scoring.
rougeL: Longest common subsequence based scoring.
Args:
rouge_types: A list of rouge types to calculate.
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching. This arg is used in the
DefaultTokenizer, but other tokenizers might or might not choose to
use this.
split_summaries: whether to add newlines between sentences for rougeLsum
tokenizer: Tokenizer object which has a tokenize() method.
Returns:
A dict mapping rouge types to Score tuples.
"""
self.rouge_types = rouge_types
if tokenizer:
self._tokenizer = tokenizer
else:
self._tokenizer = DefaultTokenizer(use_stemmer)
logging.info("Using default tokenizer.")
self._split_summaries = split_summaries
def score_multi(self, targets, prediction):
"""Calculates rouge scores between targets and prediction.
The target with the maximum f-measure is used for the final score for
each score type.
Args:
targets: list of texts containing the targets
prediction: Text containing the predicted text.
Returns:
A dict mapping each rouge type to a Score object.
Raises:
ValueError: If an invalid rouge type is encountered.
"""
score_dicts = [self.score(t, prediction) for t in targets]
max_score = {}
for k in self.rouge_types:
index = np.argmax([s[k].fmeasure for s in score_dicts])
max_score[k] = score_dicts[index][k]
return max_score
def score(self, target, prediction):
"""Calculates rouge scores between the target and prediction.
Args:
target: Text containing the target (ground truth) text.
prediction: Text containing the predicted text.
Returns:
A dict mapping each rouge type to a Score object.
Raises:
ValueError: If an invalid rouge type is encountered.
"""
# Pre-compute target tokens and prediction tokens for use by different
# types, except if only "rougeLsum" is requested.
if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum":
target_tokens = None
prediction_tokens = None
else:
target_tokens = self._tokenizer.tokenize(target)
prediction_tokens = self._tokenizer.tokenize(prediction)
result = {}
for rouge_type in self.rouge_types:
if rouge_type == "rougeL":
# Rouge from longest common subsequences.
scores = _score_lcs(target_tokens, prediction_tokens)
elif rouge_type == "rougeLsum":
# Note: Does not support multi-line text.
def get_sents(text):
if self._split_summaries:
sents = nltk.sent_tokenize(text)
else:
# Assume sentences are separated by newline.
sents = six.ensure_str(text).split("\n")
sents = [x for x in sents if len(x)]
return sents
target_tokens_list = [
self._tokenizer.tokenize(s) for s in get_sents(target)]
prediction_tokens_list = [
self._tokenizer.tokenize(s) for s in get_sents(prediction)]
scores = _summary_level_lcs(target_tokens_list,
prediction_tokens_list)
elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)):
# Rouge from n-grams.
n = int(rouge_type[5:])
if n <= 0:
raise ValueError("rougen requires positive n: %s" % rouge_type)
target_ngrams = _create_ngrams(target_tokens, n)
prediction_ngrams = _create_ngrams(prediction_tokens, n)
scores = _score_ngrams(target_ngrams, prediction_ngrams)
else:
raise ValueError("Invalid rouge type: %s" % rouge_type)
result[rouge_type] = scores
return result
def _create_ngrams(tokens, n):
"""Creates ngrams from the given list of tokens.
Args:
tokens: A list of tokens from which ngrams are created.
n: Number of tokens to use, e.g. 2 for bigrams.
Returns:
A dictionary mapping each bigram to the number of occurrences.
"""
ngrams = collections.Counter()
for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)):
ngrams[ngram] += 1
return ngrams
def _score_lcs(target_tokens, prediction_tokens):
"""Computes LCS (Longest Common Subsequence) rouge scores.
Args:
target_tokens: Tokens from the target text.
prediction_tokens: Tokens from the predicted text.
Returns:
A Score object containing computed scores.
"""
if not target_tokens or not prediction_tokens:
return Score(precision=0, recall=0, fmeasure=0)
# Compute length of LCS from the bottom up in a table (DP approach).
lcs_table = _lcs_table(target_tokens, prediction_tokens)
lcs_length = lcs_table[-1][-1]
precision = lcs_length / len(prediction_tokens)
recall = lcs_length / len(target_tokens)
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
def _lcs_table(ref, can):
"""Create 2-d LCS score table."""
rows = len(ref)
cols = len(can)
lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
for i in range(1, rows + 1):
for j in range(1, cols + 1):
if ref[i - 1] == can[j - 1]:
lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
else:
lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
return lcs_table
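# Illustrative sketch (not in the original Google code): a tiny worked example of the
# DP table above. For ref = ["a", "b", "c"] and can = ["a", "c"] the LCS is ["a", "c"],
# so the bottom-right cell holds its length, 2.
def _lcs_table_example():
    table = _lcs_table(["a", "b", "c"], ["a", "c"])
    assert table[-1][-1] == 2
    return table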
def _backtrack_norec(t, ref, can):
"""Read out LCS."""
i = len(ref)
j = len(can)
lcs = []
while i > 0 and j > 0:
if ref[i - 1] == can[j - 1]:
lcs.insert(0, i-1)
i -= 1
j -= 1
elif t[i][j - 1] > t[i - 1][j]:
j -= 1
else:
i -= 1
return lcs
def _summary_level_lcs(ref_sent, can_sent):
"""ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.
Args:
ref_sent: list of tokenized reference sentences
can_sent: list of tokenized candidate sentences
Returns:
summary level ROUGE score
"""
if not ref_sent or not can_sent:
return Score(precision=0, recall=0, fmeasure=0)
m = sum(map(len, ref_sent))
n = sum(map(len, can_sent))
if not n or not m:
return Score(precision=0, recall=0, fmeasure=0)
# get token counts to prevent double counting
token_cnts_r = collections.Counter()
token_cnts_c = collections.Counter()
for s in ref_sent:
# s is a list of tokens
token_cnts_r.update(s)
for s in can_sent:
token_cnts_c.update(s)
hits = 0
for r in ref_sent:
lcs = _union_lcs(r, can_sent)
# Prevent double-counting:
# The paper describes just computing hits += len(_union_lcs()),
# but the implementation prevents double counting. We also
# implement this as in version 1.5.5.
for t in lcs:
if token_cnts_c[t] > 0 and token_cnts_r[t] > 0:
hits += 1
token_cnts_c[t] -= 1
token_cnts_r[t] -= 1
recall = hits / m
precision = hits / n
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
def _union_lcs(ref, c_list):
"""Find union LCS between a ref sentence and list of candidate sentences.
Args:
ref: list of tokens
c_list: list of list of indices for LCS into reference summary
Returns:
List of tokens in ref representing union LCS.
"""
lcs_list = [lcs_ind(ref, c) for c in c_list]
return [ref[i] for i in _find_union(lcs_list)]
def _find_union(lcs_list):
"""Finds union LCS given a list of LCS."""
return sorted(list(set().union(*lcs_list)))
def lcs_ind(ref, can):
"""Returns one of the longest lcs."""
t = _lcs_table(ref, can)
return _backtrack_norec(t, ref, can)
def _score_ngrams(target_ngrams, prediction_ngrams):
"""Compute n-gram based rouge scores.
Args:
target_ngrams: A Counter object mapping each ngram to number of
occurrences for the target text.
prediction_ngrams: A Counter object mapping each ngram to number of
occurrences for the prediction text.
Returns:
A Score object containing computed scores.
"""
intersection_ngrams_count = 0
for ngram in six.iterkeys(target_ngrams):
intersection_ngrams_count += min(target_ngrams[ngram],
prediction_ngrams[ngram])
target_ngrams_count = sum(target_ngrams.values())
prediction_ngrams_count = sum(prediction_ngrams.values())
precision = intersection_ngrams_count / max(prediction_ngrams_count, 1)
recall = intersection_ngrams_count / max(target_ngrams_count, 1)
fmeasure = func_fmeasure(precision, recall)
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
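# Illustrative sketch (not in the original Google code): unigram ROUGE on a toy pair.
# Target {the, cat, sat} vs. prediction {the, cat} gives an overlap of 2, so
# precision = 2/2 = 1.0, recall = 2/3 and F1 = 0.8.
def _score_ngrams_example():
    target_ngrams = _create_ngrams(["the", "cat", "sat"], 1)
    prediction_ngrams = _create_ngrams(["the", "cat"], 1)
    score = _score_ngrams(target_ngrams, prediction_ngrams)
    assert abs(score.fmeasure - 0.8) < 1e-6
    return score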

View File

@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Library containing Tokenizer definitions.
The RougeScorer class can be instantiated with the tokenizers defined here. New
tokenizers can be defined by creating a subclass of the Tokenizer abstract class
and overriding the tokenize() method.
"""
import abc
from nltk.stem import porter
import re
import six
# Pre-compile regexes that are used often
NON_ALPHANUM_PATTERN = r"[^a-z0-9]+"
NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN)
SPACES_PATTERN = r"\s+"
SPACES_RE = re.compile(SPACES_PATTERN)
VALID_TOKEN_PATTERN = r"^[a-z0-9]+$"
VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN)
def _tokenize(text, stemmer):
"""Tokenize input text into a list of tokens.
This approach aims to replicate the approach taken by Chin-Yew Lin in
the original ROUGE implementation.
Args:
text: A text blob to tokenize.
stemmer: An optional stemmer.
Returns:
A list of string tokens extracted from input text.
"""
# Convert everything to lowercase.
text = text.lower()
# Replace any non-alpha-numeric characters with spaces.
text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text))
tokens = SPACES_RE.split(text)
if stemmer:
# Only stem words more than 3 characters long.
tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x
for x in tokens]
# One final check to drop any empty or invalid tokens.
tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)]
return tokens
class Tokenizer(abc.ABC):
"""Abstract base class for a tokenizer.
Subclasses of Tokenizer must implement the tokenize() method.
"""
@abc.abstractmethod
def tokenize(self, text):
raise NotImplementedError("Tokenizer must override tokenize() method")
class DefaultTokenizer(Tokenizer):
"""Default tokenizer which tokenizes on whitespace."""
def __init__(self, use_stemmer=False):
"""Constructor for DefaultTokenizer.
Args:
use_stemmer: boolean, indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
"""
self._stemmer = porter.PorterStemmer() if use_stemmer else None
def tokenize(self, text):
return _tokenize(text, self._stemmer)
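# Illustrative sketch (not in the original Google code): the default tokenizer lowercases,
# replaces non-alphanumeric characters with spaces and drops empty tokens.
def _default_tokenizer_example():
    tokens = DefaultTokenizer().tokenize("The quick-brown FOX!!")
    assert tokens == ["the", "quick", "brown", "fox"]
    return tokens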

View File

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for scoring and evaluation of text samples.
Aggregation functions use bootstrap resampling to compute confidence intervals
as per the original ROUGE perl implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
from typing import Dict
import numpy as np
import six
from six.moves import range
class Score(
collections.namedtuple("Score", ["precision", "recall", "fmeasure"])):
"""Tuple containing precision, recall, and f-measure values."""
class BaseScorer(object, metaclass=abc.ABCMeta):
"""Base class for Scorer objects."""
@abc.abstractmethod
def score(self, target, prediction):
"""Calculates score between the target and prediction.
Args:
target: Text containing the target (ground truth) text.
prediction: Text containing the predicted text.
Returns:
A dict mapping each score_type (string) to Score object.
"""
class AggregateScore(
collections.namedtuple("AggregateScore", ["low", "mid", "high"])):
"""Tuple containing confidence intervals for scores."""
class BootstrapAggregator(object):
"""Aggregates scores to provide confidence intervals.
Sample usage:
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
aggregator = Aggregator()
aggregator.add_scores(scorer.score("one two three", "one two"))
aggregator.add_scores(scorer.score("one two five six", "seven eight"))
result = aggregator.aggregate()
print(result)
{'rougeL': AggregateScore(
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
'rouge1': AggregateScore(
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
"""
def __init__(self, confidence_interval=0.95, n_samples=1000):
"""Initializes a BootstrapAggregator object.
Args:
confidence_interval: Confidence interval to compute on the mean as a
decimal.
n_samples: Number of samples to use for bootstrap resampling.
Raises:
ValueError: If invalid argument is given.
"""
if confidence_interval < 0 or confidence_interval > 1:
raise ValueError("confidence_interval must be in range [0, 1]")
if n_samples <= 0:
raise ValueError("n_samples must be positive")
self._n_samples = n_samples
self._confidence_interval = confidence_interval
self._scores = collections.defaultdict(list)
def add_scores(self, scores):
"""Adds a sample for future aggregation.
Args:
scores: Dict mapping score_type strings to a namedtuple object/class
representing a score.
"""
for score_type, score in six.iteritems(scores):
self._scores[score_type].append(score)
def aggregate(self):
"""Aggregates scores previously added using add_scores.
Returns:
A dict mapping score_type to AggregateScore objects.
"""
result = {}
for score_type, scores in six.iteritems(self._scores):
# Stack scores into a 2-d matrix of (sample, measure).
score_matrix = np.vstack(tuple(scores))
# Percentiles are returned as (interval, measure).
percentiles = self._bootstrap_resample(score_matrix)
# Extract the three intervals (low, mid, high).
intervals = tuple(
(scores[0].__class__(*percentiles[j, :]) for j in range(3)))
result[score_type] = AggregateScore(
low=intervals[0], mid=intervals[1], high=intervals[2])
return result
def _bootstrap_resample(self, matrix):
"""Performs bootstrap resampling on a matrix of scores.
Args:
matrix: A 2-d matrix of (sample, measure).
Returns:
A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
mid (row 1) and high (row 2). Mid is always the mean, while low and high
bounds are specified by self._confidence_interval (which defaults to 0.95
meaning it will return the 2.5th and 97.5th percentiles for a 95%
confidence interval on the mean).
"""
# Matrix of (bootstrap sample, measure).
sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
for i in range(self._n_samples):
sample_idx = np.random.choice(
np.arange(matrix.shape[0]), size=matrix.shape[0])
sample = matrix[sample_idx, :]
sample_mean[i, :] = np.mean(sample, axis=0)
# Take percentiles on the estimate of the mean using bootstrap samples.
# Final result is a (bounds, measure) matrix.
percentile_delta = (1 - self._confidence_interval) / 2
q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
return np.percentile(sample_mean, q, axis=0)
def func_fmeasure(precision, recall):
"""Computes f-measure given precision and recall values."""
if precision + recall > 0:
return 2 * precision * recall / (precision + recall)
else:
return 0.0
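# Illustrative sketch (not in the original Google code): func_fmeasure is the harmonic
# mean of precision and recall, e.g. 2 * 0.5 * 0.25 / (0.5 + 0.25) = 1/3.
def _fmeasure_example():
    assert abs(func_fmeasure(0.5, 0.25) - 1.0 / 3.0) < 1e-9
    assert func_fmeasure(0.0, 0.0) == 0.0
    return func_fmeasure(0.5, 0.25)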

View File

@ -0,0 +1,92 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
import argparse
import json
import re
import string
import sys
from collections import Counter
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
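# Illustrative sketch (not part of the official script): token-level F1 on a toy pair.
# Normalization drops the article "the", so the prediction "the cat sat" becomes
# ["cat", "sat"]; against the ground truth "cat sat on mat" the overlap is 2 tokens,
# giving precision = 1.0, recall = 0.5 and F1 = 2/3.
def _f1_score_example():
    score = f1_score("the cat sat", "cat sat on mat")
    assert abs(score - 2.0 / 3.0) < 1e-9
    return score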
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(dataset, predictions):
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x["text"], qa["answers"]))
prediction = predictions[qa["id"]]
exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total
return {"exact_match": exact_match, "f1": f1}
if __name__ == "__main__":
expected_version = "1.1"
parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
parser.add_argument("dataset_file", help="Dataset file")
parser.add_argument("prediction_file", help="Prediction File")
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
if dataset_json["version"] != expected_version:
print(
"Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
file=sys.stderr,
)
dataset = dataset_json["data"]
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(evaluate(dataset, predictions)))

View File

@ -0,0 +1,109 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD metric. """
import datasets
from .evaluate import evaluate
_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
booktitle={EMNLP},
year={2016}
}
"""
_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
"""
_KWARGS_DESCRIPTION = """
Computes SQuAD scores (F1 and EM).
Args:
predictions: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair as given in the references (see below)
- 'prediction_text': the text of the answer
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a Dict in the SQuAD dataset format
{
'text': list of possible texts for the answer, as a list of strings
'answer_start': list of start positions for the answer, as a list of ints
}
Note that answer_start values are not taken into account to compute the metric.
Returns:
'exact_match': Exact match (the normalized answer exactly matches the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
Examples:
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
>>> squad_metric = datasets.load_metric("squad")
>>> results = squad_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'exact_match': 100.0, 'f1': 100.0}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{
"text": datasets.Value("string"),
"answer_start": datasets.Value("int32"),
}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
def _compute(self, predictions, references):
pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
dataset = [
{
"paragraphs": [
{
"qas": [
{
"answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
"id": ref["id"],
}
for ref in references
]
}
]
}
]
score = evaluate(dataset=dataset, predictions=pred_dict)
return score

View File

@ -0,0 +1,330 @@
"""Official evaluation script for SQuAD version 2.0.
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import os
import re
import string
import sys
import numpy as np
OPTS = None
def parse_args():
parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.")
parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
parser.add_argument(
"--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)."
)
parser.add_argument(
"--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer."
)
parser.add_argument(
"--na-prob-thresh",
"-t",
type=float,
default=1.0,
help='Predict "" if no-answer probability exceeds this (default = 1.0).',
)
parser.add_argument(
"--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory."
)
parser.add_argument("--verbose", "-v", action="store_true")
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
return qid_to_has_ans
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def get_tokens(s):
if not s:
return []
return normalize_answer(s).split()
def compute_exact(a_gold, a_pred):
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
def compute_f1(a_gold, a_pred):
gold_toks = get_tokens(a_gold)
pred_toks = get_tokens(a_pred)
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
num_same = sum(common.values())
if len(gold_toks) == 0 or len(pred_toks) == 0:
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
return int(gold_toks == pred_toks)
if num_same == 0:
return 0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1 = (2 * precision * recall) / (precision + recall)
return f1
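
# Worked example (made-up strings) of the token-level F1 above: after normalization
# the gold "the cat sat" becomes ["cat", "sat"] and the prediction "cat sat down"
# becomes ["cat", "sat", "down"], so precision = 2/3, recall = 2/2 and
# F1 = 2 * (2/3 * 1) / (2/3 + 1) = 0.8.
def _example_token_f1():
    return compute_f1("the cat sat", "cat sat down")   # 0.8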
def get_raw_scores(dataset, preds):
exact_scores = {}
f1_scores = {}
for article in dataset:
for p in article["paragraphs"]:
for qa in p["qas"]:
qid = qa["id"]
if qid not in preds:
print(f"Missing prediction for {qid}")
continue
a_pred = preds[qid]
gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
if not gold_answers:
# For unanswerable questions, only correct answer is empty string
gold_answers = [""]
if a_pred != "":
exact_scores[qid] = 0
f1_scores[qid] = 0
else:
exact_scores[qid] = 1
f1_scores[qid] = 1
else:
# Take max over all gold answers
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores.values()) / total),
("f1", 100.0 * sum(f1_scores.values()) / total),
("total", total),
]
)
else:
total = len(qid_list)
return collections.OrderedDict(
[
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
("total", total),
]
)
def merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval[f"{prefix}_{k}"] = new_eval[k]
def plot_pr_curve(precisions, recalls, out_image, title):
plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.xlim([0.0, 1.05])
plt.ylim([0.0, 1.05])
plt.title(title)
plt.savefig(out_image)
plt.clf()
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i + 1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
if out_image:
plot_pr_curve(precisions, recalls, out_image, title)
return {"ap": 100.0 * avg_prec}
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
if out_image_dir and not os.path.exists(out_image_dir):
os.makedirs(out_image_dir)
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
if num_true_pos == 0:
return
pr_exact = make_precision_recall_eval(
exact_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_exact.png"),
title="Precision-Recall curve for Exact Match score",
)
pr_f1 = make_precision_recall_eval(
f1_raw,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_f1.png"),
title="Precision-Recall curve for F1 score",
)
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
pr_oracle = make_precision_recall_eval(
oracle_scores,
na_probs,
num_true_pos,
qid_to_has_ans,
out_image=os.path.join(out_image_dir, "pr_oracle.png"),
title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
)
merge_eval(main_eval, pr_exact, "pr_exact")
merge_eval(main_eval, pr_f1, "pr_f1")
merge_eval(main_eval, pr_oracle, "pr_oracle")
def histogram_na_prob(na_probs, qid_list, image_dir, name):
if not qid_list:
return
x = [na_probs[k] for k in qid_list]
weights = np.ones_like(x) / float(len(x))
plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
plt.xlabel("Model probability of no-answer")
plt.ylabel("Proportion of dataset")
plt.title(f"Histogram of no-answer probability: {name}")
plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
plt.clf()
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for i, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
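
# A minimal sketch (hypothetical scores and probabilities) of the sweep above: the
# running score starts as if every question were predicted unanswerable, questions
# are then "answered" in order of increasing no-answer probability, and the
# threshold with the highest running score is kept.
def _example_find_best_thresh():
    preds = {"q1": "1976", "q2": "blue", "q3": ""}
    scores = {"q1": 1, "q2": 0, "q3": 1}
    na_probs = {"q1": 0.1, "q2": 0.4, "q3": 0.9}
    qid_to_has_ans = {"q1": True, "q2": True, "q3": False}
    return find_best_thresh(preds, scores, na_probs, qid_to_has_ans)   # (66.66..., 0.1)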
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
main_eval["best_exact"] = best_exact
main_eval["best_exact_thresh"] = exact_thresh
main_eval["best_f1"] = best_f1
main_eval["best_f1_thresh"] = f1_thresh
def main():
with open(OPTS.data_file) as f:
dataset_json = json.load(f)
dataset = dataset_json["data"]
with open(OPTS.pred_file) as f:
preds = json.load(f)
if OPTS.na_prob_file:
with open(OPTS.na_prob_file) as f:
na_probs = json.load(f)
else:
na_probs = {k: 0.0 for k in preds}
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, preds)
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
out_eval = make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
merge_eval(out_eval, has_ans_eval, "HasAns")
if no_ans_qids:
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
merge_eval(out_eval, no_ans_eval, "NoAns")
if OPTS.na_prob_file:
find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
if OPTS.na_prob_file and OPTS.out_image_dir:
run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
if OPTS.out_file:
with open(OPTS.out_file, "w") as f:
json.dump(out_eval, f)
else:
print(json.dumps(out_eval, indent=2))
if __name__ == "__main__":
OPTS = parse_args()
if OPTS.out_image_dir:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
main()
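
# Typical invocation of this script from the command line (the file names are
# placeholders; only the flags come from parse_args above):
#   python <this_script>.py data.json pred.json --na-prob-file na_prob.json \
#       --out-file eval.json --out-image-dir out_images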


@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD v2 metric. """
import datasets
from .evaluate import (
apply_no_ans_threshold,
find_all_best_thresh,
get_raw_scores,
make_eval_dict,
make_qid_to_has_ans,
merge_eval,
)
_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
booktitle={EMNLP},
year={2016}
}
"""
_DESCRIPTION = """
This metric wraps the official scoring script for version 2 of the Stanford Question
Answering Dataset (SQuAD).
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions
written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but also
determine when no answer is supported by the paragraph and abstain from answering.
"""
_KWARGS_DESCRIPTION = """
Computes SQuAD v2 scores (F1 and EM).
Args:
predictions: List of triples for question-answers to score with the following elements:
- the question-answer 'id' field as given in the references (see below)
- the text of the answer
- the probability that the question has no answer
references: List of question-answers dictionaries with the following key-values:
- 'id': id of the question-answer pair (see above),
- 'answers': a list of Dict {'text': text of the answer as a string}
no_answer_threshold: float
Probability threshold to decide that a question has no answer.
Returns:
'exact': Exact match (the normalized answer exactly matches the gold answer)
'f1': The F-score of predicted tokens versus the gold answer
'total': Number of scores considered
'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer)
'HasAns_f1': The F-score of predicted tokens versus the gold answer
'HasAns_total': Number of scores considered
'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer)
'NoAns_f1': The F-score of predicted tokens versus the gold answer
'NoAns_total': Number of scores considered
'best_exact': Best exact match (with varying threshold)
'best_exact_thresh': No-answer probability threshold associated to the best exact match
'best_f1': Best F1 (with varying threshold)
'best_f1_thresh': No-answer probability threshold associated to the best F1
Examples:
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}]
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
>>> squad_v2_metric = datasets.load_metric("squad_v2")
>>> results = squad_v2_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SquadV2Local(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {
"id": datasets.Value("string"),
"prediction_text": datasets.Value("string"),
"no_answer_probability": datasets.Value("float32"),
},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{"text": datasets.Value("string"), "answer_start": datasets.Value("int32")}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
def _compute(self, predictions, references, no_answer_threshold=1.0):
# no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions)
dataset = [{"paragraphs": [{"qas": references}]}]
predictions = dict((p["id"], p["prediction_text"]) for p in predictions)
# qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
# has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
# no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(dataset, predictions)
out_eval = make_eval_dict(exact_raw, f1_raw)
#
#
# exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
# f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
# out_eval = make_eval_dict(exact_thresh, f1_thresh)
#
# if has_ans_qids:
# has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
# merge_eval(out_eval, has_ans_eval, "HasAns")
# if no_ans_qids:
# no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
# merge_eval(out_eval, no_ans_eval, "NoAns")
# find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
return dict(out_eval)
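
# A minimal sketch, mirroring the docstring example, of calling this local class
# directly instead of datasets.load_metric("squad_v2"); with the thresholded
# branches commented out above, only 'exact', 'f1' and 'total' are returned.
def _example_squad_v2_local():
    metric = SquadV2Local()
    predictions = [{"id": "q1", "prediction_text": "1976", "no_answer_probability": 0.0}]
    references = [{"id": "q1", "answers": {"text": ["1976"], "answer_start": [97]}}]
    results = metric.compute(predictions=predictions, references=references)
    return results   # {'exact': 100.0, 'f1': 100.0, 'total': 1}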


@@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import SimpleQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput
def preprocess_sqaud_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_sqaud_abstractive_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_boolq_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
question_column, context_column, answer_column = 'question', 'passage', 'answer'
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_boolq_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
# print(inputs,targets)
return inputs, targets
def preprocess_narrativeqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [exp['summary']['text'] for exp in examples['document']]
questions = [exp['text'] for exp in examples['question']]
answers = [ans[0]['text'] for ans in examples['answers']]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_narrativeqa_batch_pretrain(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_drop_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['passage']
questions = examples['question']
answers = examples['answers']
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_race_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 }
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_newsqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
# inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_ropes_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples[question_column]
backgrounds = examples["background"]
situations = examples["situation"]
answers = examples[answer_column]
inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
return inputs, targets
def preprocess_openbookqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
questions = examples['question_stem']
all_options = examples['choices']
answers = examples['answerKey']
options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_social_iqa_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = examples['article']
questions = examples['question']
all_options = examples['options']
answers = examples['answer']
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
ans_map = {'A': 0, 'B': 1, 'C': 2,}
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
return inputs, targets
def preprocess_dream_batch(
examples,
question_column: str,
context_column: str,
answer_column: str,
) -> Tuple[List[str], List[str]]:
contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
questions = examples['question']
all_options = examples['choice']
answers = examples['answer']
answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
targets = answers
return inputs, targets
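
# A minimal, made-up RACE-style batch illustrating the input/target format the
# multiple-choice preprocessors above produce (here preprocess_race_batch; the
# column arguments are ignored by this particular preprocessor, and "\n" below is
# the literal two-character separator written by the QAInput templates).
def _example_race_batch():
    examples = {
        "article": ["Tom plays football every Sunday."],
        "question": ["When does Tom play football?"],
        "options": [["On Monday", "On Sunday", "Never", "Every day"]],
        "answer": ["B"],
    }
    inputs, targets = preprocess_race_batch(examples, "question", "article", "answer")
    # inputs[0]:  "When does Tom play football? \n options: A. On Monday; B. On Sunday;
    #              C. Never; D. Every day \n Tom plays football every Sunday."
    # targets[0]: "On Sunday"
    return inputs, targets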


@@ -0,0 +1,248 @@
import itertools
import json
import os
import csv
import errno
import random
from random import shuffle
from typing import List
#import spacy
from tqdm import tqdm
import codecs
#import nltk
import glob
import xml.etree.ElementTree as ET
from datasets import load_dataset
#import statistic
from QAInput import StructuralQAInput, SimpleQAInput
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
EvalPrediction,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
tokenizer = AutoTokenizer.from_pretrained("./t5-base")
#nltk.download("stopwords")
#from nltk.corpus import stopwords
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
#STOPWORDS = stopwords.words("english")
#STOPWORDS = [stopword + " " for stopword in STOPWORDS]
#nlp = spacy.load("en_core_web_sm")
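
# chuli_example below ("chuli" is pinyin for "process"): wraps each raw answer list
# into a SQuAD-style {"text": [...], "answer_start": [0, ...]} dict and attaches it
# as an "answers" column so the evaluation code can read references uniformly.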
def chuli_example(examples,fname):
cur_dataset = fname
answers_start = []
if fname in ["none"]:
addition_answers = open(fname+"_answers.json",'r')
faddition = json.load(addition_answers)
for item in faddition:
answers_start.append({"text":item,"answer_start":[0]*len(item)})
else:
for item in examples["answer"]:
answers_start.append({"text":item,"answer_start":[0]*len(item)})
print(answers_start[:10])
examples = examples.add_column("answers",answers_start)
return examples
def preprocess_plain(
examples,
question_column: str,
context_column: str,
answer_column:str):
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [SimpleQAInput.qg_input(question, context) for question,context in zip(questions,contexts)]
answers = [_[0] for _ in answers]
return inputs,answers
def preprocess_proqa(
examples,
question_column: str,
context_column: str,
answer_column:str):
questions = examples[question_column]
contexts = examples[context_column]
answers = examples[answer_column]
inputs = [StructuralQAInput.qg_input(question, context) for question,context in zip(questions,contexts)]
answers = [_[0] for _ in answers]
return inputs,answers
def preprocess_function(examples):
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "question","context","answer")
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
model_inputs["labels"] = labels["input_ids"]
gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)]
format_id = task2format[dataset_name]
format_prompt_id_start = 300
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10,
format_prompt_id_start + (format_id + 1) * 10)]
task_id = task2id[dataset_name]
task_prompt_id_start = 0
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20,
task_prompt_id_start + (task_id + 1) * 20)]
domain_prompt_id_start = 20*30
domain_prompt_number = 20
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
domain_prompt_id_start + 20)]*5
input_ids = copy.deepcopy(
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in
model_inputs['attention_mask']]
return model_inputs
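
# Layout of the 140 negative "prompt placeholder" ids prepended above, using
# "squad" as an example (format_id = 0, task_id = 1):
#    10 general prompt ids : -1521 ... -1530
#    10 format prompt ids  : -301  ... -310   (300 + format_id*10 onward)
#    20 task prompt ids    : -21   ... -40    (task_id*20 onward)
#   100 domain prompt ids  : -601  ... -620, tiled 5 times (start 20*30 = 600)
# 10 + 10 + 20 + 100 = 140, matching the [1] * 140 attention-mask prefix.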
def preprocess_function_valid(examples):
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
inputs, targets = preprocess_fn(examples, "question","context","answer")
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
model_inputs["example_id"] = []
for i in range(len(model_inputs["input_ids"])):
# One example can give several spans, this is the index of the example containing this span of text.
sample_index = i #sample_mapping[i]
model_inputs["example_id"].append(examples["id"][sample_index])
gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)]
format_id = task2format[dataset_name]
format_prompt_id_start = 300
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10,
format_prompt_id_start + (format_id + 1) * 10)]
task_id = task2id[dataset_name]
task_prompt_id_start = 0
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20,
task_prompt_id_start + (task_id + 1) * 20)]
domain_prompt_id_start = 20*30
domain_prompt_number = 20
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
domain_prompt_id_start + 20)]*5
input_ids = copy.deepcopy(
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in
model_inputs['attention_mask']]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def add_id(example,index):
example.update({'id':index})
return example
def prep(raw_ds,fname):
if True:
column_names = raw_ds.column_names
global dataset_name
dataset_name = fname
train_dataset = raw_ds.map(
preprocess_function,
batched=True,
num_proc=4,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on train dataset",
)
train_dataset = train_dataset.add_column("id",range(len(train_dataset)))
train_dataset.save_to_disk("./ours/{}-train.hf".format(fname))
import copy
def prep_valid(raw_ds,fname):
global dataset_name
dataset_name = fname
eval_examples = copy.deepcopy(raw_ds)
eval_examples = chuli_example(eval_examples,fname)
column_names = raw_ds.column_names
if 'id' not in eval_examples.features.keys():
eval_examples = eval_examples.map(add_id,with_indices=True)
if True:
eval_dataset = raw_ds.map(add_id,with_indices=True)
eval_dataset = eval_dataset.map(
preprocess_function_valid,
batched=True,
num_proc=4,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on validation dataset",
)
eval_dataset.save_to_disk("./ours/{}-eval.hf".format(fname))
eval_examples.save_to_disk("./ours/{}-evalex.hf".format(fname))
from collections import Counter
import json
def main():
for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]:
print("current_dataset:",item)
print("Loading......")
data_files = {}
data_files["validation"] = item+"-eval.jsonl"
test_dataset = load_dataset("json", data_files=data_files)["validation"]
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
data_files = {}
data_files["train"] = item+"-train.jsonl"
train_dataset = load_dataset("json", data_files=data_files)["train"]
print("Loaded.")
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
prep(train_dataset,item)
print("preped")
prep_valid(test_dataset,item)
print("valid prepred")
main()
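
# This script expects <task>-eval.jsonl (and, for most tasks, <task>-train.jsonl)
# files for each task listed in main() in the working directory, and writes the
# tokenized datasets to ./ours/<task>-train.hf, ./ours/<task>-eval.hf and
# ./ours/<task>-evalex.hf.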

477
diana/metrics.py Normal file

@@ -0,0 +1,477 @@
import collections
import string
import re
import numpy as np
import json
AGG_OPS = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
COND_OPS = ['=', '>', '<', 'OP']
COND_OPS = ['=', '>', '<']
#from pyrouge import Rouge155
#from multiprocessing import Pool, cpu_count
#from contextlib import closing
from datasets import load_metric
def compute_rouge_scores(summs, refs, splitchar='.', options=None, parallel=True):
assert len(summs) == len(refs)
options = [
'-a', # evaluate all systems
'-c', 95, # confidence interval
'-m', # use Porter stemmer
'-n', 2, # max-ngram
'-w', 1.3, # weight (weighting factor for WLCS)
]
rr = Rouge(options=options)
rouge_args = []
for summ, ref in zip(summs, refs):
letter = "A"
ref_dict = {}
for r in ref:
ref_dict[letter] = [x for x in split_sentences(r, splitchar) if len(x) > 0]
letter = chr(ord(letter) + 1)
s = [x for x in split_sentences(summ, splitchar) if len(x) > 0]
rouge_args.append((s, ref_dict))
if parallel:
with closing(Pool(cpu_count()//2)) as pool:
rouge_scores = pool.starmap(rr.score_summary, rouge_args)
else:
rouge_scores = []
for s, a in rouge_args:
rouge_scores.append(rr.score_summary(s, a))
return rouge_scores
def computeROUGE(greedy, answer):
rouges = compute_rouge_scores(greedy, answer)
if len(rouges) > 0:
avg_rouges = {}
for key in rouges[0].keys():
avg_rouges[key] = sum(
[r.get(key, 0.0) for r in rouges]) / len(rouges) * 100
else:
avg_rouges = None
return avg_rouges
def lcsstr(string1, string2):
answer = 0
len1, len2 = len(string1), len(string2)
for i in range(len1):
match = 0
for j in range(len2):
if (i + j < len1 and string1[i + j] == string2[j]):
match += 1
if match > answer:
answer = match
else:
match = 0
return answer
def lcs(X, Y):
m = len(X)
n = len(Y)
L = [[None]*(n + 1) for i in range(m + 1)]
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0 :
L[i][j] = 0
elif X[i-1] == Y[j-1]:
L[i][j] = L[i-1][j-1]+1
else:
L[i][j] = max(L[i-1][j], L[i][j-1])
return L[m][n] / max(m, n) + lcsstr(X, Y) / 1e4 - min(m, n) / 1e8
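
# lcs above is used by to_lf below to match spans of the question against table
# headers: the main term is the longest-common-subsequence length normalized by the
# longer string, the /1e4 term breaks ties toward longer contiguous matches, and the
# /1e8 term breaks any remaining ties toward shorter strings.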
def normalize_text(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
# print(text)
# print("after:",' '.join(text.split( )))
return ' '.join(text.split( ))
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
# print(white_space_fix(remove_articles(remove_punc(lower(s)))))
return white_space_fix(remove_articles(remove_punc(lower(s))))
def to_delta_state(line):
delta_state = {'inform': {}, 'request': {}}
try:
if line.lower() == 'none' or line.strip() == '' or line.strip() == ';':
return delta_state
inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')]
inform_pairs = {}
for i in inform:
try:
k, v = i.split(':')
inform_pairs[k.strip()] = v.strip()
except:
pass
delta_state = {'inform': inform_pairs, 'request': request}
except:
pass
finally:
return delta_state
def update_state(state, delta):
for act, slot in delta.items():
state[act] = slot
return state
def dict_cmp(d1, d2):
def cmp(a, b):
for k1, v1 in a.items():
if k1 not in b:
return False
else:
if v1 != b[k1]:
return False
return True
return cmp(d1, d2) and cmp(d2, d1)
def to_lf(s, table):
# print("input:",s)
aggs = [y.lower() for y in AGG_OPS]
agg_to_idx = {x: i for i, x in enumerate(aggs)}
conditionals = [y.lower() for y in COND_OPS]
headers_unsorted = [(y.lower(), i) for i, y in enumerate(table['header'])]
headers = [(y.lower(), i) for i, y in enumerate(table['header'])]
headers.sort(reverse=True, key=lambda x: len(x[0]))
condition_s, conds = None, []
if 'where' in s:
s, condition_s = s.split('where', 1)
# print("s,cond",s,condition_s)
s = ' '.join(s.split()[1:-2])
s_no_agg = ' '.join(s.split()[1:])
# print(s,s_no_agg,"agg")
sel, agg = None, 0
lcss, idxs = [], []
for col, idx in headers:
lcss.append(lcs(col, s))
lcss.append(lcs(col, s_no_agg))
idxs.append(idx)
lcss = np.array(lcss)
max_id = np.argmax(lcss)
sel = idxs[max_id // 2]
if max_id % 2 == 1: # with agg
agg = agg_to_idx[s.split()[0]]
full_conditions = []
if not condition_s is None:
pattern = '|'.join(COND_OPS)
split_conds_raw = re.split(pattern, condition_s)
split_conds_raw = [conds.strip() for conds in split_conds_raw]
split_conds = [split_conds_raw[0]]
for i in range(1, len(split_conds_raw)-1):
split_conds.extend(re.split('and', split_conds_raw[i]))
split_conds += [split_conds_raw[-1]]
for i in range(0, len(split_conds), 2):
cur_s = split_conds[i]
lcss = []
for col in headers:
lcss.append(lcs(col[0], cur_s))
max_id = np.argmax(np.array(lcss))
split_conds[i] = headers[max_id][0]
for i, m in enumerate(re.finditer(pattern, condition_s)):
split_conds[2*i] = split_conds[2*i] + ' ' + m.group()
split_conds = [' '.join(split_conds[2*i:2*i+2]) for i in range(len(split_conds)//2)]
condition_s = ' and '.join(split_conds)
condition_s = ' ' + condition_s + ' '
for idx, col in enumerate(headers):
condition_s = condition_s.replace(' ' + col[0] + ' ', ' Col{} '.format(col[1]))
condition_s = condition_s.strip()
for idx, col in enumerate(conditionals):
new_s = []
for t in condition_s.split():
if t == col:
new_s.append('Cond{}'.format(idx))
else:
new_s.append(t)
condition_s = ' '.join(new_s)
s = condition_s
conds = re.split(r'(Col\d+ Cond\d+)', s)
if len(conds) == 0:
conds = [s]
conds = [x for x in conds if len(x.strip()) > 0]
full_conditions = []
for i, x in enumerate(conds):
if i % 2 == 0:
x = x.split()
col_num = int(x[0].replace('Col', ''))
opp_num = int(x[1].replace('Cond', ''))
full_conditions.append([col_num, opp_num])
else:
x = x.split()
if x[-1] == 'and':
x = x[:-1]
x = ' '.join(x)
if 'Col' in x:
new_x = []
for t in x.split():
if 'Col' in t:
idx = int(t.replace('Col', ''))
t = headers_unsorted[idx][0]
new_x.append(t)
x = new_x
x = ' '.join(x)
if 'Cond' in x:
new_x = []
for t in x.split():
if 'Cond' in t:
idx = int(t.replace('Cond', ''))
t = conditionals[idx]
new_x.append(t)
x = new_x
x = ' '.join(x)
full_conditions[-1].append(x.replace(' ', ''))
logical_form = {'sel': sel, 'conds': full_conditions, 'agg': agg}
return logical_form
def computeLFEM(greedy, answer):
count = 0
correct = 0
text_answers = []
for idx, (g, ex) in enumerate(zip(greedy, answer)):
count += 1
text_answers.append([ex['answer'].lower()])
try:
gt = ex['sql']
conds = gt['conds']
lower_conds = []
for c in conds:
lc = c
lc[2] = str(lc[2]).lower().replace(' ', '')
lower_conds.append(lc)
gt['conds'] = lower_conds
# print("gt_answer:",ex['answer'].lower())
lf = to_lf(g, ex['table'])
# print(lf,"lf")
# print(gt,"gt")
correct += lf == gt
except Exception as e:
continue
return (correct / count) * 100, text_answers
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def f1_score(prediction, ground_truth):
prediction_tokens = prediction.split()
ground_truth_tokens = ground_truth.split()
common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match(prediction, ground_truth):
return prediction == ground_truth
def computeF1(outputs, targets):
return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)]) / len(outputs) * 100
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for idx, ground_truth in enumerate(ground_truths):
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def computeEM(outputs, targets):
outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)]
return sum(outs) / len(outputs) * 100
def fix_buggy_characters(str):
return re.sub("[{}^\\\\`\u2047<]", " ", str)
def replace_punctuation(str):
return str.replace("\"", "").replace("'", "")
def score_string_similarity(str1, str2):
if str1 == str2:
return 3.0 # Better than perfect token match
str1 = fix_buggy_characters(replace_punctuation(str1))
str2 = fix_buggy_characters(replace_punctuation(str2))
if str1 == str2:
return 2.0
if " " in str1 or " " in str2:
str1_split = str1.split(" ")
str2_split = str2.split(" ")
overlap = list(set(str1_split) & set(str2_split))
return len(overlap) / max(len(str1_split), len(str2_split))
else:
if str1 == str2:
return 1.0
else:
return 0.0
def computeEMtri(outputs, targets):
options = ["entailment","neutral","contradiction"]
preds = []
for o in outputs:
scores = [score_string_similarity(opt, o) for opt in options]
max_idx = np.argmax(scores)
preds.append(options[max_idx])
print(preds)
print(targets)
outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(preds, targets)]
return sum(outs) / len(outputs) * 100
def score(answer, gold):
if len(gold) > 0:
gold = set.union(*[simplify(g) for g in gold])
answer = simplify(answer)
tp, tn, sys_pos, real_pos = 0, 0, 0, 0
if answer == gold:
if not ('unanswerable' in gold and len(gold) == 1):
tp += 1
else:
tn += 1
if not ('unanswerable' in answer and len(answer) == 1):
sys_pos += 1
if not ('unanswerable' in gold and len(gold) == 1):
real_pos += 1
return np.array([tp, tn, sys_pos, real_pos])
def simplify(answer):
return set(''.join(c for c in t if c not in string.punctuation) for t in answer.strip().lower().split()) - {'the', 'a', 'an', 'and', ''}
def computeCF1(greedy, answer):
scores = np.zeros(4)
for g, a in zip(greedy, answer):
scores += score(g, a)
tp, tn, sys_pos, real_pos = scores.tolist()
total = len(answer)
if tp == 0:
p = r = f = 0.0
else:
p = tp / float(sys_pos)
r = tp / float(real_pos)
f = 2 * p * r / (p + r)
return f * 100, p * 100, r * 100
def computeDialogue(greedy, answer):
examples = []
for idx, (g, a) in enumerate(zip(greedy, answer)):
examples.append((a[0], g, a[1], idx))
#examples.sort()
turn_request_positives = 0
turn_goal_positives = 0
joint_goal_positives = 0
ldt = None
for ex in examples:
if ldt is None or ldt.split('_')[:-1] != ex[0].split('_')[:-1]:
state, answer_state = {}, {}
ldt = ex[0]
delta_state = to_delta_state(ex[1])
answer_delta_state = to_delta_state(ex[2])
state = update_state(state, delta_state['inform'])
answer_state = update_state(answer_state, answer_delta_state['inform'])
if dict_cmp(state, answer_state):
joint_goal_positives += 1
if delta_state['request'] == answer_delta_state['request']:
turn_request_positives += 1
if dict_cmp(delta_state['inform'], answer_delta_state['inform']):
turn_goal_positives += 1
joint_goal_em = joint_goal_positives / len(examples) * 100
turn_request_em = turn_request_positives / len(examples) * 100
turn_goal_em = turn_goal_positives / len(examples) * 100
answer = [(x[-1], x[-2]) for x in examples]
#answer.sort()
answer = [[x[1]] for x in answer]
return joint_goal_em, turn_request_em, turn_goal_em, answer
def compute_metrics(data, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False,tri=False):
if rouge:
metric_func = load_metric("metric/rouge_local/rouge_metric.py")
metrics = metric_func.compute(predictions=data.predictions, references=data.label_ids)
metric_keys = ["rougeL"]
metric_values = metrics["rougeL"]
metric_dict = collections.OrderedDict(list(zip(metric_keys, [metric_values])))
return metric_dict
greedy = data.predictions
answer = data.label_ids
greedy = [_["prediction_text"] for _ in greedy]
# greedy = [_["answers"]["text"][0].lower() for _ in answer]
answer = [_["answers"]["text"] for _ in answer]
if dialogue:
addition_answers = open("./oursdeca/"+"woz.en_answers.json",'r')
answer = json.load(addition_answers)
if logical_form:
addition_answers = open("./oursdeca/"+"wikisql_answers.json",'r')
answer = json.load(addition_answers)
metric_keys = []
metric_values = []
if logical_form:
lfem, answer = computeLFEM(greedy, answer)
metric_keys += ['lfem']
metric_values += [lfem]
if tri:
em = computeEMtri(greedy,answer)
else:
em = computeEM(greedy, answer)
print(greedy[:20])
print(answer[:20])
metric_keys.append('em')
metric_values.append(em)
norm_greedy = [normalize_text(g) for g in greedy]
norm_answer = [[normalize_text(a) for a in ans] for ans in answer]
nf1 = computeF1(norm_greedy, norm_answer)
nem = computeEM(norm_greedy, norm_answer)
metric_keys.extend(['nf1', 'nem'])
metric_values.extend([nf1, nem])
if rouge:
rouge = computeROUGE(greedy, answer)
metric_keys += ['rouge1', 'rouge2', 'rougeL', 'avg_rouge']
avg_rouge = (rouge['rouge_1_f_score'] + rouge['rouge_2_f_score'] + rouge['rouge_l_f_score']) / 3
metric_values += [rouge['rouge_1_f_score'], rouge['rouge_2_f_score'], rouge['rouge_l_f_score'], avg_rouge]
if corpus_f1:
corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
metric_keys += ['corpus_f1', 'precision', 'recall']
metric_values += [corpus_f1, precision, recall]
if dialogue:
joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
avg_dialogue = (joint_goal_em + request_em) / 2
metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
metric_dict = collections.OrderedDict(list(zip(metric_keys, metric_values)))
return metric_dict
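
# A minimal sketch (made-up predictions) of how compute_metrics is fed during
# evaluation: an EvalPrediction-like object whose .predictions entries carry
# "prediction_text" and whose .label_ids entries carry SQuAD-style "answers".
def _example_compute_metrics():
    from types import SimpleNamespace
    data = SimpleNamespace(
        predictions=[{"id": "q1", "prediction_text": "1976"}],
        label_ids=[{"id": "q1", "answers": {"text": ["1976"], "answer_start": [97]}}],
    )
    return compute_metrics(data)   # OrderedDict([('em', 100.0), ('nf1', 100.0), ('nem', 100.0)])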

2406
diana/models/metadeca.py Executable file

File diff suppressed because it is too large

2386
diana/models/metadecanometa.py Executable file

File diff suppressed because it is too large

2349
diana/models/metadecanotask.py Executable file

File diff suppressed because it is too large

2409
diana/models/metat5.py Executable file

File diff suppressed because it is too large

2390
diana/models/metat5nometa.py Executable file

File diff suppressed because it is too large

2351
diana/models/metat5notask.py Executable file

File diff suppressed because it is too large

61
diana/models/utils.py Normal file

@@ -0,0 +1,61 @@
import torch.nn.functional as F
from torch import nn
import torch
import copy
#adaptive decision boundary
#[code]https://github.com/thuiar/Adaptive-Decision-Boundary
#[paper]https://www.aaai.org/AAAI21Papers/AAAI-9723.ZhangH.pdf
def euclidean_metric(a, b):
n = a.shape[0]
m = b.shape[0]
a = a.unsqueeze(1).expand(n, m, -1)
b = b.unsqueeze(0).expand(n, m, -1)
logits = -((a - b)**2).sum(dim=2)
return logits
def cosine_metric(a,b):
n = a.shape[0]
m = b.shape[0]
a = a.unsqueeze(1).expand(n, m, -1)
b = b.unsqueeze(0).expand(n, m, -1)
logits = (a*b).sum(dim=2)
# logits = -logits+1
return logits
class BoundaryLoss(nn.Module):
def __init__(self, num_labels=10, feat_dim=2):
super(BoundaryLoss, self).__init__()
self.num_labels = num_labels
self.feat_dim = feat_dim
self.delta = nn.Parameter(torch.ones(20).cuda())
nn.init.normal_(self.delta)
def forward(self, pooled_output, centroids_, labels,group=0):
centroids = copy.deepcopy(centroids_.detach())
logits = cosine_metric(pooled_output, centroids)
probs, preds = F.softmax(logits.detach(), dim=1).max(dim=1)
delta = F.softplus(self.delta)
#c = torch.Tensor([[0.0]*768]*(centroids[labels]).size(0)).to(pooled_output.device)#
# print(self.delta.size())
# print("labels",labels.size())
c = centroids[labels]
c = c/torch.norm(c, p=2, dim=1, keepdim=True)
# c.requires_grad = False
d = delta[labels]
x = pooled_output
euc_dis = 1-((c*x).sum(dim=1))#torch.norm(x - c,2, 1).view(-1)
pos_mask = (euc_dis > d).type(torch.cuda.FloatTensor)
neg_mask = (euc_dis < d).type(torch.cuda.FloatTensor)
pos_loss = (euc_dis - d) * pos_mask
neg_loss = (d - euc_dis) * neg_mask
loss = pos_loss.mean() + neg_loss.mean()
return loss, delta
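
# A minimal sketch (random tensors; a CUDA device is required because the module
# hard-codes .cuda() and torch.cuda.FloatTensor) of one adaptive-boundary step.
# Only the boundary radii `delta` receive gradients here: the centroids are
# detached inside forward and the pooled features below do not require grad.
def _example_boundary_loss():
    feat_dim, num_labels, batch = 768, 20, 4
    pooled = F.normalize(torch.randn(batch, feat_dim), dim=-1).cuda()
    centroids = torch.randn(num_labels, feat_dim).cuda()
    labels = torch.randint(0, num_labels, (batch,)).cuda()
    criterion = BoundaryLoss(num_labels=num_labels, feat_dim=feat_dim)
    loss, delta = criterion(pooled, centroids, labels)
    loss.backward()
    return loss.item(), delta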

56
diana/r.txt Normal file

@@ -0,0 +1,56 @@
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
async-timeout==4.0.2
asynctest==0.13.0
attrs==21.4.0
certifi==2022.5.18.1
charset-normalizer==2.0.12
click==8.1.3
datasets==2.2.2
dill==0.3.4
filelock==3.7.0
frozenlist==1.3.0
fsspec==2022.5.0
future==0.18.2
fuzzywuzzy==0.18.0
huggingface-hub==0.7.0
idna==3.3
importlib-metadata==4.11.4
joblib==1.1.0
multidict==6.0.2
multiprocess==0.70.12.2
nlp==0.4.0
nltk==3.7
numpy==1.21.6
packaging==21.3
pandas==1.1.5
pyarrow==8.0.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.1
PyYAML==6.0
regex==2022.4.24
requests==2.27.1
responses==0.18.0
rouge-score==0.0.4
sacremoses==0.0.53
scikit-learn==1.0.2
scipy==1.7.3
six==1.16.0
sklearn==0.0
threadpoolctl==3.1.0
tokenizers==0.10.3
tqdm==4.64.0
transformers==4.12.5
typing_extensions==4.2.0
urllib3==1.26.9
wget==3.2
xxhash==3.0.0
yarl==1.7.2
zipp==3.8.0
numpy==1.20.3
rouge_score
omegaconf
jsonlines
spacy

55
diana/readme.md Normal file

@@ -0,0 +1,55 @@
# Diana
The PyTorch implementation of the paper [Domain Incremental Lifelong Learning in an Open World](https://arxiv.org/abs/2305.06555) (ACL 2023).
## Requirements
```bash
cd DomainIncrementalLL
pip install -r r.txt
```
## QA tasks
### Data preparation
Download the datasets, unzip them, and place them in `/data_process`:
[QA Dataset](https://drive.google.com/file/d/1x8vvrdfwKCXpT_M4NA_eso8cX-LlMWsQ/view?usp=share_link)
Tokenize the plain-text datasets:
```bash
cd downstream
python create_l2p.py --model_name_or_path t5-base --dataset_name null --output_dir null
```
### Train and evaluate on 4 GPUs
```bash
bash ./diana.sh ./ll 8 48
```
### Ablations: no task prompt & no meta prompt
To evaluate the ablation without task prompts:
```bash
bash ./dianawotask.sh ./ll 8 48
```
To evaluate the ablation without meta prompts:
```bash
bash ./dianawometa.sh ./ll 8 48
```
## DecaNLP tasks
[Dataset](https://drive.google.com/file/d/1pzIoTzooaecsU4HXP_n4-_7Uuxv9i4A-/view?usp=share_link)
Extract plain texts from the raw files:
```bash
cd downstreamdeca
python extract.py
```
Tokenize the plain texts:
```bash
python toknize.py
```
Run the training and evaluation script:
```bash
bash ./diana.sh ./ll 8 48
```