add: diana

parent f570084ce5
commit 99c2ba1195

@ -0,0 +1,69 @@
# ProQA
# [paper] https://arxiv.org/abs/2205.04040


class StructuralQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'[TASK] [ABSTRACTIVE] [QUESTION] {question}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'[TASK] [BOOL] [QUESTION] {question}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'[TASK] [EXTRACTIVE] [QUESTION] {question.lower().capitalize()}. [CONTEXT] {context} the answer is: '
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'[TASK] [MultiChoice] [QUESTION] {question} [OPTIONS] {options} [CONTEXT] {context} the answer is: '
        return source_text


class LFQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'{question.lower().capitalize()} \\n {context}'
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'{question} \\n {options} \\n {context}'
        return source_text


class SimpleQAInput:
    @classmethod
    def qg_input_abstrativeqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_boolqa(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text

    @classmethod
    def qg_input_extractive_qa(cls, context, question, options=None):
        source_text = f'{question.lower().capitalize()} \\n {context}'
        return source_text

    @classmethod
    def qg_input_multirc(cls, context, question, options=None):
        source_text = f'{question} \\n {options} \\n {context}'
        return source_text
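

# Example (added for illustration; the context and question strings are invented):
# running this module directly prints the three prompt styles side by side.
if __name__ == "__main__":
    context = "The Amazon is the largest rainforest in the world."
    question = "what is the largest rainforest in the world?"
    print(StructuralQAInput.qg_input_extractive_qa(context, question))
    print(LFQAInput.qg_input_extractive_qa(context, question))
    print(SimpleQAInput.qg_input_extractive_qa(context, question))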
@ -0,0 +1,88 @@
class QGInput:
    @classmethod
    def qg_input_abstractiveqa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}, the answer is: {ans}. This is a summary task, please generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_abstractiveqa_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and the answer: The context: {context}.'
            target_text = f'[Question]: {question} [Answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_boolqa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}. The answer is: <hl> {ans} <hl>. This is a bool task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_extractive_qa(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            context = context.replace(ans, f'<hl> {ans} <hl>')
            source_text = f'The context: {context}. The answer is: {ans}. This is a extractive task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_extractive_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and the answer: [Context] {context}.'
            target_text = f'[question] {question} [answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_input_multirc(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'The context: {context}. The options are: {options}. The answer is: <hl> {ans} <hl>. This is a machine reading comprehension task, generate question: '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_negoption(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            for option in options:
                if option != ans:
                    source_text = f'Context: {context}. Answer is: <hl> {ans} <hl>. Generate false option: '
                    target_text = f'{option}'
                    outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_qapairs(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question and answer: [Context] {context}'
            target_text = f'[Question] {question} [Answer] {ans}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_negoption_withq(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            for option in options:
                if option != ans:
                    source_text = f'Generate false option: [Context] {context}. [Question] {question}. [Answer] <hl> {ans} <hl>.'
                    target_text = f'{option}'
                    outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs

    @classmethod
    def qg_intput_multirc_qfirst(cls, context, question, answers, options=None):
        outputs = []
        for ans in answers:
            source_text = f'Generate question: Context: {context}. Answer: <hl> {ans} <hl>. '
            target_text = f'{question}'
            outputs.append({'source_text': source_text, 'target_text': target_text})
        return outputs
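

# Example (added for illustration; the context, question, and answer below are invented):
# each builder returns a list of {source_text, target_text} pairs for question generation.
if __name__ == "__main__":
    pairs = QGInput.qg_input_extractive_qa(
        context="Marie Curie won two Nobel Prizes.",
        question="Who won two Nobel Prizes?",
        answers=["Marie Curie"],
    )
    print(pairs[0]['source_text'])  # highlighted-context prompt
    print(pairs[0]['target_text'])  # the question as generation target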
@ -0,0 +1,156 @@
import logging
import t5
import os
import json
import functools
import tensorflow as tf
import tensorflow_datasets as tfds

DATASETS = [
    "ai2_science_middle",
    "ai2_science_elementary",
    "arc_hard",
    "arc_easy",
    "mctest",
    "mctest_corrected_the_separator",
    "natural_questions",
    "quoref",
    "squad1_1",
    "squad2",
    "boolq",
    "multirc",
    "newsqa",
    "race_string",
    "ropes",
    "drop",
    "narrativeqa",
    "openbookqa",
    "qasc",
    "boolq_np",
    "contrast_sets_boolq",
    "contrast_sets_drop",
    "contrast_sets_quoref",
    "contrast_sets_ropes",
    "commonsenseqa",
    "qasc_with_ir",
    "openbookqa_with_ir",
    "arc_easy_with_ir",
    "arc_hard_with_ir",
    "ambigqa",
    "natural_questions_direct_ans",
    "natural_questions_with_dpr_para",
    "winogrande_xs",
    "winogrande_s",
    "winogrande_m",
    "winogrande_l",
    "winogrande_xl",
    "social_iqa",
    "physical_iqa",
]

DATA_DIR = "gs://unifiedqa/data/"


def dataset_preprocessor(ds):
    def normalize_text(text):
        """Lowercase and remove quotes from a TensorFlow string."""
        text = tf.strings.lower(text)
        text = tf.strings.regex_replace(text, "'(.*)'", r"\1")
        return text

    def to_inputs_and_targets(ex):
        return {
            "inputs": normalize_text(ex["inputs"]),
            "targets": normalize_text(ex["targets"])
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)


def get_path(data_dir1, split):
    tsv_path = {
        "train": os.path.join(data_dir1, "train.tsv"),
        "dev": os.path.join(data_dir1, "dev.tsv"),
        "test": os.path.join(data_dir1, "test.tsv")
    }
    return tsv_path[split]


def dataset_fn(split, shuffle_files=False, dataset=""):
    # We only have one file for each split.
    del shuffle_files

    # Load lines from the text file as examples.
    ds = tf.data.TextLineDataset(get_path(DATA_DIR + dataset, split))
    # Split each "<question>\t<answer>" example into a (question, answer) tuple.
    print(" >>>> about to read csv . . . ")
    ds = ds.map(
        functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                          field_delim="\t", use_quote_delim=False),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # print(" >>>> after reading csv . . . ")
    # Map each tuple to a {"inputs": ..., "targets": ...} dict.
    ds = ds.map(lambda *ex: dict(zip(["inputs", "targets"], ex)))
    # print(" >>>> after mapping . . . ")
    return ds


for dataset in DATASETS:
    print(f" >>>> reading dataset: {dataset}")
    t5.data.set_tfds_data_dir_override(DATA_DIR + dataset)
    t5.data.TaskRegistry.add(
        f"{dataset}_task",
        # Supply a function which returns a tf.data.Dataset.
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=["train", "dev", "test"],
        # Supply a function which preprocesses text from the tf.data.Dataset.
        text_preprocessor=[dataset_preprocessor],
        # Use the same vocabulary that we used for pre-training.
        sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
        # Lowercase targets before computing metrics.
        postprocess_fn=t5.data.postprocessors.lower_text,
        metric_fns=[t5.evaluation.metrics.accuracy],
    )
    print(f" >>>> adding one mixture per dataset: `{dataset}_mixture`")
    t5.data.MixtureRegistry.add(
        f"{dataset}_mixture", [f"{dataset}_task"], default_rate=1.0
    )

# dataset-pair mixtures
for dataset1 in DATASETS:
    for dataset2 in DATASETS:
        if dataset1 == dataset2:
            continue
        print(f" >>>> adding one mixture for dataset-pair: `{dataset1}_{dataset2}_mixture`")
        t5.data.MixtureRegistry.add(
            f"{dataset1}_{dataset2}_mixture",
            [f"{dataset1}_task", f"{dataset2}_task"],
            default_rate=1.0
        )

# union mixture: used for training UnifiedQA
union_datasets = [
    "narrativeqa",
    "ai2_science_middle", "ai2_science_elementary",
    "arc_hard", "arc_easy",
    "mctest_corrected_the_separator",
    "squad1_1", "squad2",
    "boolq",
    "race_string",
    "openbookqa",
]
print(" >>>> adding one mixture for `union_mixture`")
t5.data.MixtureRegistry.add(
    "union_mixture",
    [f"{d}_task" for d in union_datasets],
    default_rate=1.0
)

# leave-one-out mixtures
for dd in union_datasets:
    filtered_datasets = [f"{d}_task" for d in union_datasets if d != dd]
    assert len(filtered_datasets) < len(union_datasets)
    t5.data.MixtureRegistry.add(
        f"union_minus_{dd}_mixture",
        filtered_datasets,
        default_rate=1.0
    )
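

# Example (added for illustration): once this module has been imported, the registered
# tasks/mixtures can be inspected through the standard t5.data registry API. This is a
# sketch only; it assumes the t5 version matching the registration calls above and
# read access to the TSV files under gs://unifiedqa/data/.
if __name__ == "__main__":
    task = t5.data.TaskRegistry.get("boolq_task")
    ds = task.get_dataset(split="dev", sequence_length={"inputs": 512, "targets": 128})
    for ex in tfds.as_numpy(ds.take(2)):
        print(ex)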
@ -0,0 +1,782 @@
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.nopt5 import T5ForConditionalGeneration as PromptT5
|
||||
# from models.promptT5 import PromptT5
|
||||
from downstream.dataset_processors import *
|
||||
from downstream.trainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
# check_min_version("4.12.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
task2id = {}
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
task2format={}
|
||||
for k,v in dataset2format.items():
|
||||
if v=="extractive":
|
||||
task2format[k]=0
|
||||
elif v=="abstractive":
|
||||
task2format[k]=1
|
||||
else:
|
||||
task2format[k]=2
|
||||
# task2id[v] = len(task2id.keys())
|
||||
# task2id = {v:k for k,v in enumerate(task_list)}
|
||||
task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5, 'boolq': 6, 'bool': 7, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17}
|
||||
|
||||
# task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive':1 , 'race': 2, 'multichoice': 2, 'boolq': 3, 'bool': 3, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17}
|
||||
# print(task2id)
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
"help": "An optional input evaluation data file to evaluate the metrics on "
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
"help": "An optional input test data file to evaluate the metrics on a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks supported by the task prompt pool."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of QA formats (task types) supported by the format prompt pool."},
)
prompt_number: Optional[int] = field(
default=40,
metadata={"help": "The number of soft prompt tokens allocated to each task/format."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the model input."},
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "question","context","answer")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
|
||||
|
||||
|
||||
gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)]
|
||||
format_id = task2format[dataset_name]
|
||||
format_prompt_id_start = 500
|
||||
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40,
|
||||
format_prompt_id_start + (format_id + 1) * 40)]
|
||||
task_id = task2id[dataset_name]
|
||||
task_prompt_id_start = 0
|
||||
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40,
|
||||
task_prompt_id_start + (task_id + 1) * 40)]
|
||||
|
||||
|
||||
domain_prompt_id_start = 800
|
||||
domain_prompt_number = 20
|
||||
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
|
||||
domain_prompt_id_start + 20)]*5
|
||||
input_ids = copy.deepcopy(
|
||||
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
|
||||
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
|
||||
model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in
|
||||
model_inputs['attention_mask']]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
|
||||
return model_inputs
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "question","context","answer")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)]
|
||||
format_id = task2format[dataset_name]
|
||||
format_prompt_id_start = 500
|
||||
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40,
|
||||
format_prompt_id_start + (format_id + 1) * 40)]
|
||||
task_id = task2id[dataset_name]
|
||||
task_prompt_id_start = 0
|
||||
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40,
|
||||
task_prompt_id_start + (task_id + 1) * 40)]
|
||||
|
||||
|
||||
domain_prompt_id_start = 800
|
||||
domain_prompt_number = 20
|
||||
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
|
||||
domain_prompt_id_start + 20)]*5
|
||||
input_ids = copy.deepcopy(
|
||||
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
|
||||
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
|
||||
model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in
|
||||
model_inputs['attention_mask']]
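# Note (added for clarity): the hard-coded ranges above prepend 200 virtual prompt tokens
# per example: 20 general prompts (ids -1521..-1540), 40 format prompts (starting at -501
# for format_id 0), 40 task prompts (starting at -1 for task_id 0), and 20 domain prompts
# repeated 5 times (starting at -801). 20 + 40 + 40 + 100 = 200, which is why the
# attention mask is extended with [1] * 200.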
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
if 'same' in model_args.model_name_or_path:
|
||||
task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive': 1, 'race': 2, 'multichoice': 2,
|
||||
'boolq': 3, 'bool': 3, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
|
||||
'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}
|
||||
else:
|
||||
task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5,
|
||||
'boolq': 6, 'bool': 7, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12,
|
||||
'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17}
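# Note (added for clarity): when the checkpoint path contains 'same', each seed task and its
# format alias share one id (squad/extractive -> 0, narrativeqa/abstractive -> 1,
# race/multichoice -> 2, boolq/bool -> 3), so they reuse the same block of task-prompt
# embeddings; otherwise every task name keeps a distinct id.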
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'squad',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'accuracy',
|
||||
'narrativeqa': 'rouge',
|
||||
'race': 'accuracy',
|
||||
'quoref': 'squad',
|
||||
'ropes': 'squad',
|
||||
'drop': 'squad',
|
||||
'nqopen': 'squad',
|
||||
# 'multirc': 'accuracy',
|
||||
'boolq_np': 'accuracy',
|
||||
'openbookqa': 'accuracy',
|
||||
'mctest': 'accuracy',
|
||||
'social_iqa': 'accuracy',
|
||||
'dream': 'accuracy',
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
#reload format specific task-prompt for newly involved task
|
||||
#format_prompts###task_promptsf
|
||||
|
||||
data_args.reload_from_trained_prompt = False#@
|
||||
data_args.load_from_format_task_id = False#@
|
||||
### before pretrain come !!!!!!
|
||||
|
||||
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the formats do not match'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
|
||||
device = torch.device("cuda")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
for ds_name in task2id.keys():
|
||||
if (not ds_name in ["extractive","abstractive","multichoice","bool","boolq","boolq_np","ropes"]):
|
||||
# data_args.dataset_name = cur_dataset
|
||||
cur_dataset = ds_name
|
||||
data_args.dataset_name = cur_dataset
|
||||
if True:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
if not data_args.dataset_name in ['newsqa', 'nqopen', 'mctest', 'social_iqa']:
|
||||
if data_args.dataset_name == "race":
|
||||
data_args.dataset_config_name = "all"
|
||||
elif data_args.dataset_name == "openbookqa":
|
||||
data_args.dataset_config_name = "main"
|
||||
elif data_args.dataset_name == "dream":
|
||||
data_args.dataset_config_name = "plain_text"
|
||||
else:
|
||||
data_args.dataset_config_name = "plain_text"
|
||||
raw_datasets = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
||||
)
|
||||
if data_args.dataset_name in ['ropes']:
|
||||
# add answer_start (not used for squad evaluation but required)
|
||||
def add_answer_start(example):
|
||||
example['answers'].update({"answer_start": [0]})
|
||||
return example
|
||||
raw_datasets = raw_datasets.map(add_answer_start)
|
||||
elif data_args.dataset_name in ['drop']:
|
||||
# add answer_start (not used for squad evaluation but required)
|
||||
# add answers (for squad evaluation)
|
||||
def add_answers(example):
|
||||
answers = []
|
||||
answer_start = []
|
||||
for _a in example['answers_spans']['spans']:
|
||||
answers.append(_a)
|
||||
answer_start.append(-1)
|
||||
example['answers'] = {"text": answers, "answer_start": answer_start}
|
||||
return example
|
||||
raw_datasets = raw_datasets.map(add_answers)
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
|
||||
|
||||
else:
|
||||
data_files = {}
|
||||
basic_file = "../data_process/data/"+data_args.dataset_name+"/"
|
||||
data_files["train"] = basic_file+"train.json"
|
||||
# extension = data_args.train_file.split(".")[-1]
|
||||
|
||||
data_files["validation"] = basic_file+"dev.json"
|
||||
# extension = data_args.validation_file.split(".")[-1]
|
||||
|
||||
if data_args.dataset_name in ['newsqa', 'nqopen', 'multirc', 'boolq_np', 'mctest', 'social_iqa']:
|
||||
raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
print(f"Unknown dataset {data_args.dataset_name}")
|
||||
raise NotImplementedError
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
metric = load_metric(dataset_name_to_metric[data_args.dataset_name])
|
||||
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = raw_datasets["train"]
|
||||
|
||||
if True:
|
||||
all_num = list(range(0, len(train_dataset)))
|
||||
random.shuffle(all_num)
|
||||
selected_indices = all_num[:100]
|
||||
replay_dataset = train_dataset.select(selected_indices)
|
||||
|
||||
# train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
replay_dataset = replay_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
desc="Running tokenizer on replay dataset",
|
||||
)
|
||||
replay_dataloaders[ds_name] = replay_dataset
|
||||
# replay_dataset = load_from_disk("./processed/{}-replay.hf".format(ds_name))
|
||||
# print(replay_dataset)
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
desc="Running tokenizer on train dataset",
|
||||
)
|
||||
# print(train_dataset)
|
||||
train_dataloaders[ds_name] = train_dataset
|
||||
train_dataset.save_to_disk("./ours/{}-train.hf".format(ds_name))
|
||||
# train_dataset = load_from_disk("./processed/{}-train.hf".format(ds_name))
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = raw_datasets["validation"]
|
||||
def add_id(example,index):
|
||||
example.update({'id':index})
|
||||
return example
|
||||
if 'id' not in eval_examples.features.keys():
|
||||
eval_examples = eval_examples.map(add_id,with_indices=True)
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||
eval_dataset = eval_examples.map(
|
||||
preprocess_validation_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
desc="Running tokenizer on validation dataset",
|
||||
)
|
||||
eval_dataloaders[ds_name] = (eval_dataset,eval_examples)
|
||||
eval_dataset.save_to_disk("./ours/{}-eval.hf".format(ds_name))
|
||||
|
||||
eval_examples.save_to_disk("./ours/{}-evalex.hf".format(ds_name))
|
||||
|
||||
|
||||
|
||||
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
|
||||
if len(languages) > 0:
|
||||
kwargs["language"] = languages
|
||||
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import StructuralQAInput as QAInput
# from QAInput import QAInputNoPrompt as QAInput


def preprocess_sqaud_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_sqaud_abstractive_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_boolq_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    question_column, context_column, answer_column = 'question', 'passage', 'answer'
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [str(ans) for ans in answers]  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets


def preprocess_boolq_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets


def preprocess_narrativeqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [exp['summary']['text'] for exp in examples['document']]
    questions = [exp['text'] for exp in examples['question']]
    answers = [ans[0]['text'] for ans in examples['answers']]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = answers  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_narrativeqa_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_drop_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['passage']
    questions = examples['question']
    answers = examples['answers']
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_race_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets


def preprocess_newsqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_ropes_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    backgrounds = examples["background"]
    situations = examples["situation"]
    answers = examples[answer_column]

    inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_openbookqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples['question_stem']
    all_options = examples['choices']
    answers = examples['answerKey']
    options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
    inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets


def preprocess_social_iqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets


def preprocess_dream_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
    questions = examples['question']
    all_options = examples['choice']
    answers = examples['answer']
    answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    targets = answers
    return inputs, targets
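

# Example (added for illustration): applying one of these processors to a small slice of
# SQuAD, mirroring how the training script calls them (same column names as used there).
if __name__ == "__main__":
    from datasets import load_dataset
    raw = load_dataset("squad", split="validation[:4]")
    inputs, targets = preprocess_sqaud_batch(raw[:4], "question", "context", "answers")
    print(inputs[0])
    print(targets[0])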
@ -0,0 +1,53 @@
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -u diana.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
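# Illustrative invocation (script filename, output path and batch sizes are
# assumptions, not values taken from the repo):
#
#   bash finetune_t5base.sh ./outputs 8 16
#
# This resolves SAVING_PATH to ./outputs/lifelong/try and launches diana.py on a
# single process with per-device train/eval batch sizes of 8 and 16; stdout and
# stderr are tee'd to ${SAVING_PATH}/log.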
|
|
@@ -0,0 +1,827 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
import logging
|
||||
import os
|
||||
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metat5 import T5ForConditionalGeneration as PromptT5
|
||||
from downstream.dataset_processors import *
|
||||
from downstream.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'],
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
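# Illustrative sketch of the prompt-table layout implied by the slicing code further
# down in this file (prompt_number defaults to 100; the helper names below are
# assumptions, not part of the codebase):
#
#   def format_slice(fmt, prompt_number=100):
#       start = format2id[fmt] * prompt_number          # format prompts come first
#       return start, start + prompt_number
#
#   def task_slice(task, prompt_number=100):
#       task_start = prompt_number * len(format2id)     # task prompts follow the format block
#       start = task_start + task2id[task] * prompt_number
#       return start, start + prompt_number
#
#   # e.g. format_slice('multichoice') -> (200, 300), task_slice('dream') -> (1300, 1400)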
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
|
||||
default=30,
|
||||
metadata={"help": "Maximum number of tasks for which task-specific prompts are allocated."},
|
||||
)
|
||||
qa_task_type_num: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of QA task formats (e.g. extractive, abstractive, multi-choice)."},
|
||||
)
|
||||
prompt_number: Optional[int] = field(
|
||||
default=100,
|
||||
metadata={"help": "Number of prompt tokens allocated per format/task block."},
|
||||
)
|
||||
add_task_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to prepend task-specific prompt tokens to the model input."},
|
||||
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
if True:
|
||||
pass
|
||||
# logger.info(f'Loading task {data_args.dataset_name} prompt')
|
||||
# format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
# task_id = task2id[data_args.dataset_name]
|
||||
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
|
||||
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
|
||||
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
|
||||
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
|
||||
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
|
||||
# input_ids = copy.deepcopy(
|
||||
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
# model_inputs['input_ids'] = input_ids
|
||||
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
|
||||
# model_inputs['attention_mask']]
|
||||
|
||||
# input_ids = copy.deepcopy([input])
|
||||
return model_inputs
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model,trainset):
|
||||
with torch.no_grad():
|
||||
device = 'cuda:0'
|
||||
search_upbound = len(trainset)//4
|
||||
query_idxs = [None]*30
|
||||
keys = model.encoder.domain_keys
|
||||
for idx,item in enumerate(trainset.select(range(search_upbound))):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
result = torch.matmul(query,keys.t())
|
||||
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
|
||||
key_sel = None
|
||||
for key_idx in result:
|
||||
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
|
||||
key_sel = key_idx
|
||||
break
|
||||
if key_sel is not None:
|
||||
if query_idxs[key_sel] is None:
|
||||
query_idxs[key_sel] = [idx]
|
||||
else:
|
||||
query_idxs[key_sel].append(idx)
|
||||
total_idxs = []
|
||||
for item in query_idxs:
|
||||
try:
|
||||
total_idxs.extend(item[:3])
|
||||
except:
|
||||
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
|
||||
total_idxs = list(set(total_idxs))
|
||||
total_idxs = random.sample(total_idxs,50)
|
||||
sub_set = trainset.select(total_idxs)
|
||||
|
||||
features = []
|
||||
for idx,item in enumerate(sub_set):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
features.append(query.detach().cpu().numpy())
|
||||
|
||||
|
||||
|
||||
return sub_set,features
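# Illustrative note: the subset returned here is the 50-example replay sample for the
# current task and `features` holds the corresponding query vectors. In the training
# loop below they are consumed roughly as follows (names as in this file):
#
#   replay_set, feats = save_load_diverse_sample(model, trainds)
#   all_replay = replay_set if all_replay is None else concatenate_datasets([all_replay, replay_set])
#   all_features.extend(feats)
#   ...
#   model.encoder.add_negs(all_ids, all_features)   # register accumulated queries as negatives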
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
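# Illustrative lookup (both mappings are keyed by the same dataset names):
#
#   metric = load_metric(dataset_name_to_metric['dream'])     # -> metric/accuracy.py
#   preprocess_fn = dataset_name_to_func('dream')             # -> preprocess_dream_batch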
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
# Reload the format-level prompt to initialize the task prompt for a newly added task
|
||||
# (format prompts occupy the first len(format2id) * prompt_number rows of the table; task prompts follow)
|
||||
|
||||
data_args.reload_from_trained_prompt = False  # hard-coded override
|
||||
data_args.load_from_format_task_id = False  # hard-coded override
|
||||
### these flags must be set before continual training starts
|
||||
|
||||
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = format2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number  # start of the format-level prompt block
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the format mapping does not match'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
|
||||
device = torch.device("cuda")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f" `{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
|
||||
f" model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursfinallong/{}-eval.hf".format(ds_name)),load_from_disk("./oursfinallong/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
|
||||
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
|
||||
for cur_dataset in task_sequence:
|
||||
p=200
|
||||
if p>len(load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset))):
|
||||
p = len(load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset)))
|
||||
|
||||
trainds = load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset))
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["drop","quoref","dream"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric(dataset_name_to_metric[cur_dataset])
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,trainds])
|
||||
else:
|
||||
fused = trainds
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
if training_args.local_rank<=0:
|
||||
save_set,features = save_load_diverse_sample(model,trainds)
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
|
||||
|
||||
|
||||
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
|
||||
|
||||
|
||||
|
||||
model.encoder.reset_train_count()
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=all_replay,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric(dataset_name_to_metric[pre_dataset])
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
max_length, num_beams, ignore_keys_for_eval = None, None, None
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
kwargs = {}  # kwargs is otherwise undefined here (leftover from the HF example template)
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
|
||||
if len(languages) > 0:
|
||||
kwargs["language"] = languages
|
||||
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,53 @@
|
|||
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 diana.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
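# Illustrative invocation of the distributed variant (script filename, output path
# and batch sizes are assumptions): bash finetune_ddp.sh ./outputs 8 16
# torch.distributed.launch spawns 4 processes, so the effective training batch size
# is 4 * TRAIN_BATCH_SIZE * gradient_accumulation_steps (= 4 * 8 * 2 = 64 in that example).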
|
|
@@ -0,0 +1,798 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
import logging
|
||||
import os
|
||||
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metat5nometa import T5ForConditionalGeneration as PromptT5
|
||||
from downstream.dataset_processors import *
|
||||
from downstream.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'],
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
|
||||
default=30,
|
||||
metadata={"help": "Maximum number of tasks for which task-specific prompts are allocated."},
|
||||
)
|
||||
qa_task_type_num: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "Number of QA task formats (e.g. extractive, abstractive, multi-choice)."},
|
||||
)
|
||||
prompt_number: Optional[int] = field(
|
||||
default=100,
|
||||
metadata={"help": "Number of prompt tokens allocated per format/task block."},
|
||||
)
|
||||
add_task_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to prepend task-specific prompt tokens to the model input."},
|
||||
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
if True:
|
||||
pass
|
||||
# logger.info(f'Loading task {data_args.dataset_name} prompt')
|
||||
# format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
# task_id = task2id[data_args.dataset_name]
|
||||
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
|
||||
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
|
||||
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
|
||||
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
|
||||
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
|
||||
# input_ids = copy.deepcopy(
|
||||
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
# model_inputs['input_ids'] = input_ids
|
||||
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
|
||||
# model_inputs['attention_mask']]
|
||||
|
||||
# input_ids = copy.deepcopy([input])
|
||||
return model_inputs
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model,trainset):
|
||||
features = []
|
||||
device='cuda:0'
|
||||
with torch.no_grad():
|
||||
sub_set = trainset.select(range(50))
|
||||
|
||||
for idx,item in enumerate(sub_set):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
features.append(query.detach().cpu().numpy())
|
||||
|
||||
|
||||
|
||||
return sub_set,features
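# Illustrative note: unlike the key-diverse sampler in the meta variant above, this
# version simply keeps the first 50 training examples as the replay subset; the
# returned (sub_set, features) pair has the same shape as in the meta variant.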
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
    "You're running a T5 model but didn't provide a source prefix, which T5 models usually "
    "expect, e.g. `--source_prefix 'translate English to German: '`"
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
# Reload the format-specific task prompt for a newly added task
# (format prompts / task prompts).
data_args.reload_from_trained_prompt = False  # hard-coded override
data_args.load_from_format_task_id = False    # hard-coded override
# NOTE: these overrides are applied before continual training starts.
|
||||
|
||||
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path, 'Must specify the path of the stored prompt'
prompt_info = torch.load(data_args.trained_prompt_path)
assert prompt_info['task2id'][data_args.dataset_name] == task2id[data_args.dataset_name], f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
assert prompt_info['format2id'].keys() == format2id.keys(), 'the format mappings do not match'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
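# --- Editor's sketch (not part of the original script) ---------------------------
# The index arithmetic above implies a single prompt-embedding table laid out as
# [format prompts | task prompts]: prompt_number rows per format first, then
# prompt_number rows per task. A hedged helper that only reproduces that layout
# (names mirror the variables above; it is illustrative and not used anywhere):
def _prompt_slices(dataset_name, prompt_number, format2id, task2id, dataset2format):
    fmt_id = format2id[dataset2format[dataset_name]]
    format_slice = (fmt_id * prompt_number, (fmt_id + 1) * prompt_number)
    task_start = prompt_number * len(format2id)
    task_slice = (task_start + task2id[dataset_name] * prompt_number,
                  task_start + (task2id[dataset_name] + 1) * prompt_number)
    return format_slice, task_slice

# e.g. with prompt_number=100 and 3 formats, task prompts start at row 300.
# ----------------------------------------------------------------------------------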
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
    "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
    f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory."
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
    f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the "
    f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursnometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursnometa/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
|
||||
for cur_dataset in task_sequence:
|
||||
|
||||
trainds = load_from_disk("./oursnometa/{}-train.hf".format(cur_dataset)).select(range(1000))
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["drop","quoref","dream"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric(dataset_name_to_metric[cur_dataset])
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,trainds])
|
||||
else:
|
||||
fused = trainds
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
save_set,features = save_load_diverse_sample(model,trainds)
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
model.encoder.reset_train_count()
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=all_replay,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric(dataset_name_to_metric[pre_dataset])
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
max_length, num_beams, ignore_keys_for_eval = None, None, None  # NOTE: discards the generation settings computed above; evaluate() falls back to defaults
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
# Leftover from the upstream translation example: `kwargs` is never defined in this
# script, so the language bookkeeping below is dead code and is kept commented out.
# languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
# if len(languages) > 0:
#     kwargs["language"] = languages
return None
|
||||
|
||||
|
||||
# -
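# --- Editor's sketch (not part of the original script) ---------------------------
# High-level shape of the rehearsal loop implemented inside main() above, written
# out as a standalone sketch. The callables are placeholders for the trainer,
# replay sampler and evaluator used above; this function is illustrative only.
def _rehearsal_loop(task_sequence, load_train, train_on, sample_replay, evaluate_on):
    replay_buffer = []                              # grows by ~50 cached examples per task
    seen_tasks = []
    for task in task_sequence:
        seen_tasks.append(task)
        train_data = load_train(task)
        train_on(train_data + replay_buffer)        # new task data fused with replay
        replay_buffer += sample_replay(train_data)  # keep a small diverse subset
        train_on(replay_buffer)                     # extra pass over the buffer only
        for seen in seen_tasks:                     # measure forgetting so far
            evaluate_on(seen)
# ----------------------------------------------------------------------------------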
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,53 @@
|
|||
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
|
|
@ -0,0 +1,827 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
import logging
|
||||
import os
|
||||
os.environ['NCCL_ASYNC_ERROR_HANDLING']='1'
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metat5notask import T5ForConditionalGeneration as PromptT5
|
||||
from downstream.dataset_processors import *
|
||||
from downstream.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'],
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10}
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
    default=30,
    metadata={"help": "Maximum number of tasks the prompt pool is sized for."},
)
qa_task_type_num: Optional[int] = field(
    default=4,
    metadata={"help": "Number of QA task formats (extractive/abstractive/multi-choice/boolean)."},
)
prompt_number: Optional[int] = field(
    default=100,
    metadata={"help": "Number of prompt embeddings allocated per task and per format."},
)
add_task_prompt: Optional[bool] = field(
    default=False,
    metadata={"help": "Whether to prepend task-specific prompt tokens to the encoder input."},
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
if True:
|
||||
pass
|
||||
# logger.info(f'Loading task {data_args.dataset_name} prompt')
|
||||
# format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
# task_id = task2id[data_args.dataset_name]
|
||||
# format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, (
|
||||
# format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1)))
|
||||
# task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number
|
||||
# logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number))
|
||||
# task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number,
|
||||
# task_prompt_id_start + (task_id + 1) * data_args.prompt_number)]
|
||||
# input_ids = copy.deepcopy(
|
||||
# [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
#input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']])
|
||||
# model_inputs['input_ids'] = input_ids
|
||||
# model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in
|
||||
# model_inputs['attention_mask']]
|
||||
|
||||
# input_ids = copy.deepcopy([input])
|
||||
return model_inputs
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model, trainset):
    # Select ~50 replay examples that are spread across the encoder's domain keys,
    # then cache their query vectors so they can be reused as negatives later on.
    with torch.no_grad():
        device = 'cuda:0'  # NOTE: hard-coded to the first GPU
        search_upbound = len(trainset) // 4
        query_idxs = [None] * 30
        keys = model.encoder.domain_keys
        # Greedily assign each searched example to one of its top-5 closest keys,
        # keeping at most 3 examples per key so the replay set stays diverse.
        for idx, item in enumerate(trainset.select(range(search_upbound))):
            query = model.encoder.get_query_vector(
                input_ids=torch.tensor([item['input_ids']]).long().to(device),
                attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
                return_dict=True)
            result = torch.matmul(query, keys.t())
            result = torch.topk(result, 5).indices[0].cpu().numpy().tolist()
            key_sel = None
            for key_idx in result:
                if query_idxs[key_idx] is None or len(query_idxs[key_idx]) < 3:
                    key_sel = key_idx
                    break
            if key_sel is not None:
                if query_idxs[key_sel] is None:
                    query_idxs[key_sel] = [idx]
                else:
                    query_idxs[key_sel].append(idx)
        total_idxs = []
        for item in query_idxs:
            try:
                total_idxs.extend(item[:3])
            except:
                # Bucket stayed empty (item is None): back-fill with random examples
                # from the part of the training set that was not searched above.
                total_idxs.extend(random.sample(list(range(search_upbound, len(trainset))), 3))
        total_idxs = list(set(total_idxs))
        total_idxs = random.sample(total_idxs, 50)
        sub_set = trainset.select(total_idxs)

        # Cache the query vectors of the selected replay examples.
        features = []
        for idx, item in enumerate(sub_set):
            query = model.encoder.get_query_vector(
                input_ids=torch.tensor([item['input_ids']]).long().to(device),
                attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
                return_dict=True)
            features.append(query.detach().cpu().numpy())

    return sub_set, features
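# --- Editor's sketch (not part of the original script) ---------------------------
# The selection rule above, isolated: each example's query vector is matched to its
# top-5 domain keys, and the example is kept only if one of those key "buckets"
# still has fewer than 3 members, which spreads the replay set across keys. A
# simplified numpy-only version (no model, no random fallback, no final cap of 50):
import numpy as np

def _bucketed_selection(queries, keys, top_k=5, per_key=3):
    buckets = {i: [] for i in range(len(keys))}
    scores = queries @ keys.T                      # (n_examples, n_keys) similarities
    for idx, row in enumerate(scores):
        for key_idx in np.argsort(-row)[:top_k]:   # closest keys first
            if len(buckets[key_idx]) < per_key:
                buckets[key_idx].append(idx)
                break
    return sorted(i for b in buckets.values() for i in b)
# ----------------------------------------------------------------------------------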
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
#reload format specific task-prompt for newly involved task
|
||||
#format_prompts###task_promptsf
|
||||
|
||||
data_args.reload_from_trained_prompt = False#@
|
||||
data_args.load_from_format_task_id = False#@
|
||||
### before pretrain come !!!!!!
|
||||
|
||||
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
    device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
    "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
    f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory."
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
    f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the "
    f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursnotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursnotask/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa']
|
||||
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
|
||||
for cur_dataset in task_sequence:
|
||||
# NOTE: `p` is computed but never used below (likely a leftover cap on the number
# of training examples per task).
p = 200
if p > len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))):
    p = len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset)))

trainds = load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["drop","quoref","dream"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric(dataset_name_to_metric[cur_dataset])
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,trainds])
|
||||
else:
|
||||
fused = trainds
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
if training_args.local_rank<=0:
|
||||
save_set,features = save_load_diverse_sample(model,trainds)
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
model.encoder.reset_train_count()
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=all_replay,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric(dataset_name_to_metric[pre_dataset])
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
max_length, num_beams, ignore_keys_for_eval = None, None, None
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
# Leftover from the upstream translation example: `kwargs` is never defined in this
# script, so the language bookkeeping below is dead code and is kept commented out.
# languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
# if len(languages) > 0:
#     kwargs["language"] = languages
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,53 @@
|
|||
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
|
|
@ -0,0 +1,556 @@
|
|||
from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import nltk
|
||||
import datasets
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
from pathlib import Path
|
||||
import nltk
|
||||
# nltk.download('punkt')
|
||||
from transformers.trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalLoopOutput,
|
||||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
PredictionOutput,
|
||||
ShardedDDPOption,
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
denumpify_detensorize,
|
||||
get_last_checkpoint,
|
||||
number_of_arguments,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
import warnings
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
find_batch_size,
|
||||
get_parameter_names,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
nested_truncate,
|
||||
nested_xla_mesh_reduce,
|
||||
reissue_pt_warnings,
|
||||
)
|
||||
from transformers.file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_full_repo_name,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_in_notebook,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_tpu_available,
|
||||
)
|
||||
TRAINING_ARGS_NAME = "training_args.bin"
|
||||
TRAINER_STATE_NAME = "trainer_state.json"
|
||||
OPTIMIZER_NAME = "optimizer.pt"
|
||||
SCHEDULER_NAME = "scheduler.pt"
|
||||
SCALER_NAME = "scaler.pt"
|
||||
from transformers.trainer_utils import PredictionOutput,EvalLoopOutput
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
def fix_buggy_characters(str):
    # Replace characters that frequently break string comparison with spaces.
    return re.sub("[{}^\\\\`\u2047<]", " ", str)

def replace_punctuation(str):
    return str.replace("\"", "").replace("'", "")

def score_string_similarity(str1, str2):
    if str1 == str2:
        return 3.0   # Better than perfect token match
    str1 = fix_buggy_characters(replace_punctuation(str1))
    str2 = fix_buggy_characters(replace_punctuation(str2))
    if str1 == str2:
        return 2.0
    if " " in str1 or " " in str2:
        str1_split = str1.split(" ")
        str2_split = str2.split(" ")
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        # Single tokens: equality was already ruled out above, so this is 0.0.
        if str1 == str2:
            return 1.0
        else:
            return 0.0
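# --- Editor's note (not part of the original trainer) ----------------------------
# A few worked examples of score_string_similarity() under the definitions above;
# note that matching is case-sensitive and only quotes/odd characters are stripped.
def _similarity_examples():
    assert score_string_similarity("blue car", "blue car") == 3.0                    # exact match
    assert score_string_similarity('"yes"', "yes") == 2.0                            # quotes stripped
    assert abs(score_string_similarity("the blue car", "blue car") - 2 / 3) < 1e-9   # token overlap
    assert score_string_similarity("Paris", "paris") == 0.0                          # no lowercasing
# ----------------------------------------------------------------------------------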
class QuestionAnsweringTrainer(Seq2SeqTrainer):
|
||||
def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.eval_examples = eval_examples
|
||||
self.answer_column_name = answer_column_name
|
||||
self.dataset_name = dataset_name
|
||||
self.tokenizer = tokenizer
|
||||
if dataset_name in ['squad', "quoref", 'ropes', 'drop', 'nqopen']:
|
||||
self.post_process_function = self._post_process_squad
|
||||
elif dataset_name in ['boolq', 'multirc', 'boolq_np']:
|
||||
self.post_process_function = self._post_process_boolq
|
||||
elif dataset_name== 'narrativeqa':
|
||||
self.post_process_function = self._post_process_narrative_qa
|
||||
elif dataset_name in ['race', 'mctest', 'social_iqa']:
|
||||
self.post_process_function = self._post_process_race
|
||||
elif dataset_name in ["squad_v2", "newsqa"]:
|
||||
self.post_process_function = self._post_process_squad_v2
|
||||
elif dataset_name in ['openbookqa']:
|
||||
self.post_process_function = self._post_process_openbookqa
|
||||
elif dataset_name in ['dream']:
|
||||
self.post_process_function = self._post_process_dream
|
||||
else:
    print(f"{dataset_name} has no QuestionAnsweringTrainer.post_process_function!")
    raise NotImplementedError
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
            # Flag passed through to the custom model's forward pass (presumably used
            # to switch on training-time behaviour there).
            inputs["train"] = True
        else:
            labels = None
        outputs, sim_loss = model(**inputs)
|
||||
|
||||
# print(model.encoder.)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if labels is not None:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
        # Add the auxiliary similarity loss returned by the model with a fixed weight of 0.1.
        loss += sim_loss * 0.1
        return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
def _post_process_squad(
|
||||
self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
|
||||
,version_2_with_negative=False):
|
||||
# Decode the predicted tokens.
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
|
||||
predictions = {}
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(examples):
|
||||
# This is the index of the feature associated to the current example.
|
||||
            try:
                feature_index = feature_per_example[example_index]
            except KeyError:
                # Dump the mapping for debugging, then re-raise: continuing silently would
                # leave feature_index undefined (or stale) for this example.
                print(feature_per_example)
                print(example_index)
                print(len(feature_per_example))
                raise
            predictions[example["id"]] = decoded_preds[feature_index]
|
||||
|
||||
# Format the result to the format the metric expects.
|
||||
if version_2_with_negative:
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
    def _post_process_squad_v2(
            self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput,
            tokenizer=None, stage="eval"):
|
||||
# Decode the predicted tokens.
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
|
||||
predictions = {}
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(examples):
|
||||
# This is the index of the feature associated to the current example.
|
||||
feature_index = feature_per_example[example_index]
|
||||
predictions[example["id"]] = decoded_preds[feature_index]
|
||||
|
||||
# Format the result to the format the metric expects.
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
@classmethod
|
||||
def postprocess_text(cls,preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
labels = [label.strip() for label in labels]
|
||||
# rougeLSum expects newline after each sentence
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
return preds, labels
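        # e.g. (roughly, with the default punkt sentence tokenizer; illustrative only):
        #   postprocess_text(["Hello there. General Kenobi."], ["Hi there. General Kenobi."])
        #   -> (["Hello there.\nGeneral Kenobi."], ["Hi there.\nGeneral Kenobi."])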
|
||||
|
||||
def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
outputs = []
|
||||
# for pred in preds:
|
||||
# if 'True' == pred:
|
||||
# outputs.append(1)
|
||||
# elif 'False' == pred:
|
||||
# outputs.append(0)
|
||||
# else:
|
||||
# outputs.append(0)
|
||||
# references = [1 if ex['answer'] else 0 for ex in examples]
|
||||
|
||||
references = []
|
||||
for pred, ref in zip(preds, examples):
|
||||
ans = ref['answer']
|
||||
references.append(1 if ans else 0)
|
||||
if pred.lower() in ['true', 'false']:
|
||||
outputs.append(1 if pred.lower() == 'true' else 0)
|
||||
else:
|
||||
                # A generation that is neither "true" nor "false" is counted as incorrect.
                outputs.append(0 if ans else 1)
|
||||
|
||||
        assert len(references) == len(outputs)
        return EvalPrediction(predictions=outputs, label_ids=references)
|
||||
|
||||
def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
references = [exp['answers'][0]['text'] for exp in examples]
|
||||
        preds, references = self.postprocess_text(preds, references)
        assert len(preds) == len(references)
        return EvalPrediction(predictions=preds, label_ids=references)
|
||||
|
||||
def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
references = [exp['answers'][0]['text'] for exp in examples]
|
||||
        preds, references = self.postprocess_text(preds, references)
        assert len(preds) == len(references)
        return EvalPrediction(predictions=preds, label_ids=references)
|
||||
|
||||
def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples]
|
||||
assert(len(references)==len(decoded_preds))
|
||||
gold_ids , pred_ids = [],[]
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
# reference = json.loads(reference)
|
||||
gold = int(ord(reference['answer'].strip()) - ord('A'))
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
# print('options',options,type(options))
|
||||
# print('prediction',prediction)
|
||||
# print('answer',gold)
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
# selected_ans = chr(ord('A') + max_idx)
|
||||
# print(len(references),len(decoded_preds))
|
||||
|
||||
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
|
||||
|
||||
def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples]
|
||||
assert len(references) == len(decoded_preds)
|
||||
gold_ids, pred_ids = [], []
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
# reference = json.loads(reference)
|
||||
gold = int(ord(reference['answer'].strip()) - ord('A'))
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
|
||||
|
||||
def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples]
|
||||
assert(len(references)==len(decoded_preds))
|
||||
gold_ids, pred_ids = [],[]
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
gold = reference['options'].index(reference['answer'])
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
return EvalPrediction(predictions=pred_ids, label_ids = gold_ids)
|
||||
|
||||
    def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None, ignore_keys=None,
                 metric_key_prefix: str = "eval", max_length: Optional[int] = None, num_beams: Optional[int] = None):
|
||||
self._memory_tracker.start()
|
||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
eval_examples = self.eval_examples if eval_examples is None else eval_examples
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
try:
|
||||
output = eval_loop(
|
||||
eval_dataloader,
|
||||
description="Evaluation",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
|
||||
if self.post_process_function is not None and self.compute_metrics is not None:
|
||||
            eval_preds = self.post_process_function(eval_examples, eval_dataset, output, tokenizer)
            metrics = self.compute_metrics(eval_preds)
            if self.dataset_name == 'narrativeqa':
                metrics = {key: value * 100 for key, value in metrics.items()}
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
print(metrics)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
return metrics
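    # Typical call (a sketch, assuming the hypothetical setup above):
    #   metrics = trainer.evaluate(tokenizer=tokenizer, max_length=64, num_beams=4)
    # Generated sequences are decoded, routed through the dataset-specific
    # _post_process_* hook, and then scored with self.compute_metrics.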
|
||||
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||
# want to save except FullyShardedDDP.
|
||||
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||
|
||||
# Save model checkpoint
|
||||
        # Only the best checkpoint (saved under 'best-checkpoint' below) is written to disk;
        # regular per-step checkpoints, optimizer/scheduler state and trainer state are skipped.
        save_only_best = True
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is not None and trial is not None:
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
run_id = trial.number
|
||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||
from ray import tune
|
||||
|
||||
run_id = tune.get_trial_id()
|
||||
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
|
||||
run_id = trial.id
|
||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||
run_dir = os.path.join(self.args.output_dir, run_name)
|
||||
else:
|
||||
run_dir = self.args.output_dir
|
||||
self.store_flos()
|
||||
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
if not save_only_best:
|
||||
self.save_model(output_dir)
|
||||
if self.deepspeed:
|
||||
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
|
||||
# config `stage3_gather_fp16_weights_on_model_save` is True
|
||||
self.deepspeed.save_checkpoint(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
self.optimizer.consolidate_state_dict()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
if smp.dp_rank() == 0:
|
||||
# Consolidate the state dict on all processed of dp_rank 0
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
# Save it and the scheduler on the main process
|
||||
if self.args.should_save and not save_only_best:
|
||||
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.deepspeed and not save_only_best:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
self.do_grad_scaling=True
|
||||
# if self.do_grad_scaling:
|
||||
# torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||
metric_to_check = self.args.metric_for_best_model
|
||||
if not metric_to_check.startswith("eval_"):
|
||||
metric_to_check = f"eval_{metric_to_check}"
|
||||
metric_value = metrics[metric_to_check]
|
||||
|
||||
operator = np.greater if self.args.greater_is_better else np.less
|
||||
if (
|
||||
self.state.best_metric is None
|
||||
or self.state.best_model_checkpoint is None
|
||||
or operator(metric_value, self.state.best_metric)
|
||||
):
|
||||
self.state.best_metric = metric_value
|
||||
self.state.best_model_checkpoint = output_dir
|
||||
best_dir = os.path.join(run_dir,'best-checkpoint')
|
||||
# if not os.path.exists(best_dir):
|
||||
# os.mkdir(best_dir,exist_ok=True)
|
||||
if self.args.should_save:
|
||||
self.save_model(best_dir)
|
||||
self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME))
|
||||
|
||||
|
||||
# Save the Trainer state
|
||||
if self.args.should_save and not save_only_best:
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
|
||||
# Save RNG state in non-distributed training
|
||||
rng_states = {
|
||||
"python": random.getstate(),
|
||||
"numpy": np.random.get_state(),
|
||||
"cpu": torch.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
if self.args.local_rank == -1:
|
||||
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
|
||||
else:
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
rng_states["xla"] = xm.get_rng_state()
|
||||
|
||||
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
|
||||
# not yet exist.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||
if local_rank == -1:
|
||||
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||
else:
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
self._push_from_checkpoint(output_dir)
|
||||
|
||||
# Maybe delete some older checkpoints.
|
||||
if self.args.should_save and not save_only_best:
|
||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||
|
||||
|
||||
|
||||
|
||||
# def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict",
|
||||
# max_length: Optional[int] = None,num_beams: Optional[int] = None):
|
||||
# self._memory_tracker.start()
|
||||
# self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
# self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
# predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset
|
||||
# predict_dataloader = self.get_eval_dataloader(predict_dataset)
|
||||
# predict_examples = predict_examples
|
||||
#
|
||||
# # Temporarily disable metric computation, we will do it in the loop here.
|
||||
# compute_metrics = self.compute_metrics
|
||||
# self.compute_metrics = None
|
||||
# eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
# try:
|
||||
# output = eval_loop(
|
||||
# predict_dataloader,
|
||||
# description="Predict",
|
||||
# # No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# # self.args.prediction_loss_only
|
||||
# prediction_loss_only=True if compute_metrics is None else None,
|
||||
# ignore_keys=ignore_keys,
|
||||
# )
|
||||
# finally:
|
||||
# self.compute_metrics = compute_metrics
|
||||
#
|
||||
# if self.post_process_function is not None and self.compute_metrics is not None:
|
||||
# predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer)
|
||||
# metrics = self.compute_metrics(predict_preds)
|
||||
# if self.dataset_name=='narrativeqa':
|
||||
# metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
|
||||
# metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
# # Prefix all keys with metric_key_prefix + '_'
|
||||
# for key in list(metrics.keys()):
|
||||
# if not key.startswith(f"{metric_key_prefix}_"):
|
||||
# metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
#
|
||||
# self.log(metrics)
|
||||
# else:
|
||||
# metrics = {}
|
||||
#
|
||||
# self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
# self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
# return metrics
|
|
@@ -0,0 +1,105 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Accuracy metric."""
|
||||
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = """
|
||||
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
|
||||
Accuracy = (TP + TN) / (TP + TN + FP + FN)
|
||||
Where:
|
||||
TP: True positive
|
||||
TN: True negative
|
||||
FP: False positive
|
||||
FN: False negative
|
||||
"""
|
||||
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Args:
|
||||
predictions (`list` of `int`): Predicted labels.
|
||||
references (`list` of `int`): Ground truth labels.
|
||||
normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
|
||||
    sample_weight (`list` of `float`): Sample weights. Defaults to None.

Returns:
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0 if `normalize` is set to `True`, or the number of examples input if `normalize` is set to `False`. A higher score means higher accuracy.
|
||||
|
||||
Examples:
|
||||
|
||||
Example 1-A simple example
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
|
||||
>>> print(results)
|
||||
{'accuracy': 0.5}
|
||||
|
||||
Example 2-The same as Example 1, except with `normalize` set to `False`.
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
|
||||
>>> print(results)
|
||||
{'accuracy': 3.0}
|
||||
|
||||
Example 3-The same as Example 1, except with `sample_weight` set.
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
|
||||
>>> print(results)
|
||||
{'accuracy': 0.8778625954198473}
|
||||
"""
|
||||
|
||||
|
||||
_CITATION = """
|
||||
@article{scikit-learn,
|
||||
title={Scikit-learn: Machine Learning in {P}ython},
|
||||
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
|
||||
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
|
||||
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
|
||||
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
|
||||
journal={Journal of Machine Learning Research},
|
||||
volume={12},
|
||||
pages={2825--2830},
|
||||
year={2011}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Accuracy(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("int32")),
|
||||
"references": datasets.Sequence(datasets.Value("int32")),
|
||||
}
|
||||
if self.config_name == "multilabel"
|
||||
else {
|
||||
"predictions": datasets.Value("int32"),
|
||||
"references": datasets.Value("int32"),
|
||||
}
|
||||
),
|
||||
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, normalize=True, sample_weight=None):
|
||||
return {
|
||||
"accuracy": float(
|
||||
accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
|
||||
)
|
||||
}
|
|
@@ -0,0 +1,126 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BLEU metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@INPROCEEDINGS{Papineni02bleu:a,
|
||||
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
|
||||
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
|
||||
booktitle = {},
|
||||
year = {2002},
|
||||
pages = {311--318}
|
||||
}
|
||||
@inproceedings{lin-och-2004-orange,
|
||||
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
|
||||
author = "Lin, Chin-Yew and
|
||||
Och, Franz Josef",
|
||||
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
|
||||
month = "aug 23{--}aug 27",
|
||||
year = "2004",
|
||||
address = "Geneva, Switzerland",
|
||||
publisher = "COLING",
|
||||
url = "https://www.aclweb.org/anthology/C04-1072",
|
||||
pages = "501--507",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
|
||||
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
|
||||
the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
|
||||
remains one of the most popular automated and inexpensive metrics.
|
||||
|
||||
Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations.
|
||||
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
|
||||
are not taken into account.
|
||||
|
||||
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
|
||||
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
|
||||
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
|
||||
reference translations will increase the BLEU score.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
predictions: list of translations to score.
|
||||
Each translation should be tokenized into a list of tokens.
|
||||
references: list of lists of references for each translation.
|
||||
Each reference should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
Returns:
|
||||
'bleu': bleu score,
|
||||
'precisions': geometric mean of n-gram precisions,
|
||||
'brevity_penalty': brevity penalty,
|
||||
'length_ratio': ratio of lengths,
|
||||
'translation_length': translation_length,
|
||||
'reference_length': reference_length
|
||||
Examples:
|
||||
|
||||
>>> predictions = [
|
||||
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
|
||||
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
|
||||
... ]
|
||||
>>> references = [
|
||||
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
|
||||
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
|
||||
... ]
|
||||
>>> bleu = datasets.load_metric("bleu")
|
||||
>>> results = bleu.compute(predictions=predictions, references=references)
|
||||
>>> print(results["bleu"])
|
||||
1.0
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Bleu(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
|
||||
"references": datasets.Sequence(
|
||||
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
|
||||
),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/BLEU",
|
||||
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
|
||||
],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, max_order=4, smooth=False):
|
||||
score = compute_bleu(
|
||||
reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth
|
||||
)
|
||||
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
|
||||
return {
|
||||
"bleu": bleu,
|
||||
"precisions": precisions,
|
||||
"brevity_penalty": bp,
|
||||
"length_ratio": ratio,
|
||||
"translation_length": translation_length,
|
||||
"reference_length": reference_length,
|
||||
}
|
|
@@ -0,0 +1,219 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BLEU metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@INPROCEEDINGS{Papineni02bleu:a,
|
||||
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
|
||||
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
|
||||
booktitle = {},
|
||||
year = {2002},
|
||||
pages = {311--318}
|
||||
}
|
||||
@inproceedings{lin-och-2004-orange,
|
||||
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
|
||||
author = "Lin, Chin-Yew and
|
||||
Och, Franz Josef",
|
||||
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
|
||||
month = "aug 23{--}aug 27",
|
||||
year = "2004",
|
||||
address = "Geneva, Switzerland",
|
||||
publisher = "COLING",
|
||||
url = "https://www.aclweb.org/anthology/C04-1072",
|
||||
pages = "501--507",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
|
||||
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
|
||||
the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
|
||||
remains one of the most popular automated and inexpensive metrics.
|
||||
|
||||
Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations.
|
||||
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
|
||||
are not taken into account.
|
||||
|
||||
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
|
||||
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
|
||||
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
|
||||
reference translations will increase the BLEU score.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
predictions: list of translations to score.
|
||||
Each translation should be tokenized into a list of tokens.
|
||||
references: list of lists of references for each translation.
|
||||
Each reference should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
Returns:
|
||||
'bleu': bleu score,
|
||||
'precisions': geometric mean of n-gram precisions,
|
||||
'brevity_penalty': brevity penalty,
|
||||
'length_ratio': ratio of lengths,
|
||||
'translation_length': translation_length,
|
||||
'reference_length': reference_length
|
||||
Examples:
|
||||
|
||||
>>> predictions = [
|
||||
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
|
||||
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
|
||||
... ]
|
||||
>>> references = [
|
||||
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
|
||||
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
|
||||
... ]
|
||||
>>> bleu = datasets.load_metric("bleu")
|
||||
>>> results = bleu.compute(predictions=predictions, references=references)
|
||||
>>> print(results["bleu"])
|
||||
1.0
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
|
||||
def _get_ngrams(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.
  Args:
    segment: text segment from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this
      method.
  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
|
||||
ngram_counts = collections.Counter()
|
||||
for order in range(1, max_order + 1):
|
||||
for i in range(0, len(segment) - order + 1):
|
||||
ngram = tuple(segment[i:i+order])
|
||||
ngram_counts[ngram] += 1
|
||||
return ngram_counts
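# e.g. (illustrative only):
#   _get_ngrams(["a", "b", "a"], 2)
#   -> Counter({("a",): 2, ("b",): 1, ("a", "b"): 1, ("b", "a"): 1})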
|
||||
|
||||
|
||||
def compute_bleu(reference_corpus, translation_corpus, max_order=1,
|
||||
smooth=False):
|
||||
"""Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
reference_corpus: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation_corpus: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
  Returns:
    6-Tuple of (bleu, precisions, brevity_penalty, length_ratio,
    translation_length, reference_length).
  """
|
||||
matches_by_order = [0] * max_order
|
||||
possible_matches_by_order = [0] * max_order
|
||||
reference_length = 0
|
||||
translation_length = 0
|
||||
for (references, translation) in zip(reference_corpus,
|
||||
translation_corpus):
|
||||
# translation = references[0]
|
||||
translation = [item.lower() for item in translation]
|
||||
|
||||
reference_length += min(len(r) for r in references)
|
||||
translation_length += len(translation)
|
||||
|
||||
merged_ref_ngram_counts = collections.Counter()
|
||||
|
||||
for reference in references:
|
||||
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
|
||||
# print(merged_ref_ngram_counts)
|
||||
translation_ngram_counts = _get_ngrams(translation, max_order)
|
||||
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
||||
for ngram in overlap:
|
||||
matches_by_order[len(ngram)-1] += overlap[ngram]
|
||||
for order in range(1, max_order+1):
|
||||
possible_matches = len(translation) - order + 1
|
||||
if possible_matches > 0:
|
||||
possible_matches_by_order[order-1] += possible_matches
|
||||
|
||||
|
||||
precisions = [0] * max_order
|
||||
for i in range(0, max_order):
|
||||
if smooth:
|
||||
precisions[i] = ((matches_by_order[i] + 1.) /
|
||||
(possible_matches_by_order[i] + 1.))
|
||||
else:
|
||||
if possible_matches_by_order[i] > 0:
|
||||
precisions[i] = (float(matches_by_order[i]) /
|
||||
possible_matches_by_order[i])
|
||||
else:
|
||||
precisions[i] = 0.0
|
||||
|
||||
|
||||
if min(precisions) > 0:
|
||||
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
|
||||
geo_mean = math.exp(p_log_sum)
|
||||
else:
|
||||
geo_mean = 0
|
||||
|
||||
ratio = float(translation_length) / reference_length
|
||||
|
||||
if ratio > 1.0:
|
||||
bp = 1.
|
||||
else:
|
||||
bp = math.exp(1 - 1. / ratio)
|
||||
|
||||
bleu = geo_mean * bp
|
||||
|
||||
|
||||
return (bleu, precisions, bp, ratio, translation_length, reference_length)
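# Sanity check for this modified variant (unigram counts, candidate side lowercased);
# illustrative only, not part of the original code:
#   refs = [[["the", "cat", "sat"]]]
#   cand = [["The", "cat", "sat"]]
#   bleu, *_ = compute_bleu(refs, cand, max_order=1)  # -> 1.0 after lowercasing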
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Bleu(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
|
||||
"references": datasets.Sequence(
|
||||
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
|
||||
),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/BLEU",
|
||||
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
|
||||
],
|
||||
)
|
||||
|
||||
    def _compute(self, predictions, references, max_order=4, smooth=False):
        # Note: this local variant always computes unigram BLEU; the max_order
        # argument is overridden to 1 below.
        score = compute_bleu(
            reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth
        )
|
||||
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
|
||||
return {
|
||||
"bleu": bleu,
|
||||
"precisions": precisions,
|
||||
"brevity_penalty": bp,
|
||||
"length_ratio": ratio,
|
||||
"translation_length": translation_length,
|
||||
"reference_length": reference_length,
|
||||
}
|
|
@@ -0,0 +1,133 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" ROUGE metric from Google Research github repo. """
|
||||
|
||||
# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
|
||||
import absl # Here to have a nice missing dependency error message early on
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy # Here to have a nice missing dependency error message early on
|
||||
import six # Here to have a nice missing dependency error message early on
|
||||
from .rouge_score import RougeScorer
|
||||
from .scoring import *
|
||||
from .rouge_tokenizers import *
|
||||
import datasets
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@inproceedings{lin-2004-rouge,
|
||||
title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
|
||||
author = "Lin, Chin-Yew",
|
||||
booktitle = "Text Summarization Branches Out",
|
||||
month = jul,
|
||||
year = "2004",
|
||||
address = "Barcelona, Spain",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://www.aclweb.org/anthology/W04-1013",
|
||||
pages = "74--81",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
|
||||
evaluating automatic summarization and machine translation software in natural language processing.
|
||||
The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.
|
||||
|
||||
Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.
|
||||
|
||||
This metrics is a wrapper around Google Research reimplementation of ROUGE:
|
||||
https://github.com/google-research/google-research/tree/master/rouge
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Calculates average rouge scores for a list of hypotheses and references
|
||||
Args:
|
||||
predictions: list of predictions to score. Each prediction
|
||||
should be a string with tokens separated by spaces.
|
||||
references: list of reference for each prediction. Each
|
||||
reference should be a string with tokens separated by spaces.
|
||||
rouge_types: A list of rouge types to calculate.
|
||||
Valid names:
|
||||
`"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
|
||||
`"rougeL"`: Longest common subsequence based scoring.
|
||||
`"rougeLSum"`: rougeLsum splits text using `"\n"`.
|
||||
See details in https://github.com/huggingface/datasets/issues/617
|
||||
use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
|
||||
use_aggregator: Return aggregates if this is set to True
|
||||
Returns:
|
||||
rouge1: rouge_1 (precision, recall, f1),
|
||||
rouge2: rouge_2 (precision, recall, f1),
|
||||
rougeL: rouge_l (precision, recall, f1),
|
||||
rougeLsum: rouge_lsum (precision, recall, f1)
|
||||
Examples:
|
||||
|
||||
>>> rouge = datasets.load_metric('rouge')
|
||||
>>> predictions = ["hello there", "general kenobi"]
|
||||
>>> references = ["hello there", "general kenobi"]
|
||||
>>> results = rouge.compute(predictions=predictions, references=references)
|
||||
>>> print(list(results.keys()))
|
||||
['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
|
||||
>>> print(results["rouge1"])
|
||||
AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
|
||||
>>> print(results["rouge1"].mid.fmeasure)
|
||||
1.0
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Rouge(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Value("string", id="sequence"),
|
||||
"references": datasets.Value("string", id="sequence"),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/ROUGE_(metric)",
|
||||
"https://github.com/google-research/google-research/tree/master/rouge",
|
||||
],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
|
||||
if rouge_types is None:
|
||||
rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
||||
|
||||
scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
|
||||
if use_aggregator:
|
||||
aggregator = BootstrapAggregator()
|
||||
else:
|
||||
scores = []
|
||||
|
||||
for ref, pred in zip(references, predictions):
|
||||
score = scorer.score(ref, pred)
|
||||
if use_aggregator:
|
||||
aggregator.add_scores(score)
|
||||
else:
|
||||
scores.append(score)
|
||||
|
||||
if use_aggregator:
|
||||
result = aggregator.aggregate()
|
||||
for key in rouge_types:
|
||||
result[key] = result[key].mid.fmeasure
|
||||
else:
|
||||
result = {}
|
||||
for key in scores[0]:
|
||||
result[key] = list(score[key] for score in scores)
|
||||
|
||||
return result
|
|
@@ -0,0 +1,318 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Computes rouge scores between two text blobs.
|
||||
Implementation replicates the functionality in the original ROUGE package. See:
|
||||
Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
|
||||
Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
|
||||
Barcelona, Spain, July 25 - 26, 2004.
|
||||
Default options are equivalent to running:
|
||||
ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
|
||||
Or with use_stemmer=True:
|
||||
ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
|
||||
In these examples settings.xml lists input files and formats.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
|
||||
from absl import logging
|
||||
import nltk
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import map
|
||||
from six.moves import range
|
||||
from .scoring import *
|
||||
from .rouge_tokenizers import *
|
||||
|
||||
|
||||
|
||||
class RougeScorer(BaseScorer):
|
||||
"""Calculate rouges scores between two blobs of text.
|
||||
Sample usage:
|
||||
scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
|
||||
scores = scorer.score('The quick brown fox jumps over the lazy dog',
|
||||
'The quick brown dog jumps on the log.')
|
||||
"""
|
||||
|
||||
def __init__(self, rouge_types, use_stemmer=False, split_summaries=False,
|
||||
tokenizer=None):
|
||||
"""Initializes a new RougeScorer.
|
||||
Valid rouge types that can be computed are:
|
||||
rougen (e.g. rouge1, rouge2): n-gram based scoring.
|
||||
rougeL: Longest common subsequence based scoring.
|
||||
Args:
|
||||
rouge_types: A list of rouge types to calculate.
|
||||
use_stemmer: Bool indicating whether Porter stemmer should be used to
|
||||
strip word suffixes to improve matching. This arg is used in the
|
||||
DefaultTokenizer, but other tokenizers might or might not choose to
|
||||
use this.
|
||||
split_summaries: whether to add newlines between sentences for rougeLsum
|
||||
tokenizer: Tokenizer object which has a tokenize() method.
|
||||
Returns:
|
||||
A dict mapping rouge types to Score tuples.
|
||||
"""
|
||||
|
||||
self.rouge_types = rouge_types
|
||||
if tokenizer:
|
||||
self._tokenizer = tokenizer
|
||||
else:
|
||||
self._tokenizer = DefaultTokenizer(use_stemmer)
|
||||
logging.info("Using default tokenizer.")
|
||||
|
||||
self._split_summaries = split_summaries
|
||||
|
||||
def score_multi(self, targets, prediction):
|
||||
"""Calculates rouge scores between targets and prediction.
|
||||
    The target with the maximum f-measure is used for the final score for
    each score type.
|
||||
Args:
|
||||
targets: list of texts containing the targets
|
||||
prediction: Text containing the predicted text.
|
||||
Returns:
|
||||
A dict mapping each rouge type to a Score object.
|
||||
Raises:
|
||||
ValueError: If an invalid rouge type is encountered.
|
||||
"""
|
||||
score_dicts = [self.score(t, prediction) for t in targets]
|
||||
max_score = {}
|
||||
for k in self.rouge_types:
|
||||
index = np.argmax([s[k].fmeasure for s in score_dicts])
|
||||
max_score[k] = score_dicts[index][k]
|
||||
|
||||
return max_score
|
||||
|
||||
def score(self, target, prediction):
|
||||
"""Calculates rouge scores between the target and prediction.
|
||||
Args:
|
||||
      target: Text containing the target (ground truth) text. For multiple
        targets, use score_multi instead.
      prediction: Text containing the predicted text.
|
||||
Returns:
|
||||
A dict mapping each rouge type to a Score object.
|
||||
Raises:
|
||||
ValueError: If an invalid rouge type is encountered.
|
||||
"""
|
||||
|
||||
# Pre-compute target tokens and prediction tokens for use by different
|
||||
# types, except if only "rougeLsum" is requested.
|
||||
if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum":
|
||||
target_tokens = None
|
||||
prediction_tokens = None
|
||||
else:
|
||||
target_tokens = self._tokenizer.tokenize(target)
|
||||
prediction_tokens = self._tokenizer.tokenize(prediction)
|
||||
result = {}
|
||||
|
||||
for rouge_type in self.rouge_types:
|
||||
if rouge_type == "rougeL":
|
||||
# Rouge from longest common subsequences.
|
||||
scores = _score_lcs(target_tokens, prediction_tokens)
|
||||
elif rouge_type == "rougeLsum":
|
||||
# Note: Does not support multi-line text.
|
||||
def get_sents(text):
|
||||
if self._split_summaries:
|
||||
sents = nltk.sent_tokenize(text)
|
||||
else:
|
||||
# Assume sentences are separated by newline.
|
||||
sents = six.ensure_str(text).split("\n")
|
||||
sents = [x for x in sents if len(x)]
|
||||
return sents
|
||||
|
||||
target_tokens_list = [
|
||||
self._tokenizer.tokenize(s) for s in get_sents(target)]
|
||||
prediction_tokens_list = [
|
||||
self._tokenizer.tokenize(s) for s in get_sents(prediction)]
|
||||
|
||||
scores = _summary_level_lcs(target_tokens_list,
|
||||
prediction_tokens_list)
|
||||
elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)):
|
||||
# Rouge from n-grams.
|
||||
n = int(rouge_type[5:])
|
||||
if n <= 0:
|
||||
raise ValueError("rougen requires positive n: %s" % rouge_type)
|
||||
target_ngrams = _create_ngrams(target_tokens, n)
|
||||
prediction_ngrams = _create_ngrams(prediction_tokens, n)
|
||||
scores = _score_ngrams(target_ngrams, prediction_ngrams)
|
||||
else:
|
||||
raise ValueError("Invalid rouge type: %s" % rouge_type)
      result[rouge_type] = scores

    return result


def _create_ngrams(tokens, n):
  """Creates ngrams from the given list of tokens.

  Args:
    tokens: A list of tokens from which ngrams are created.
    n: Number of tokens to use, e.g. 2 for bigrams.

  Returns:
    A dictionary mapping each n-gram to the number of occurrences.
  """

  ngrams = collections.Counter()
  for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)):
    ngrams[ngram] += 1
  return ngrams


def _score_lcs(target_tokens, prediction_tokens):
  """Computes LCS (Longest Common Subsequence) rouge scores.

  Args:
    target_tokens: Tokens from the target text.
    prediction_tokens: Tokens from the predicted text.

  Returns:
    A Score object containing computed scores.
  """

  if not target_tokens or not prediction_tokens:
    return Score(precision=0, recall=0, fmeasure=0)

  # Compute length of LCS from the bottom up in a table (DP approach).
  lcs_table = _lcs_table(target_tokens, prediction_tokens)
  lcs_length = lcs_table[-1][-1]

  precision = lcs_length / len(prediction_tokens)
  recall = lcs_length / len(target_tokens)
  fmeasure = func_fmeasure(precision, recall)

  return Score(precision=precision, recall=recall, fmeasure=fmeasure)


def _lcs_table(ref, can):
  """Create 2-d LCS score table."""
  rows = len(ref)
  cols = len(can)
  lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
  for i in range(1, rows + 1):
    for j in range(1, cols + 1):
      if ref[i - 1] == can[j - 1]:
        lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
      else:
        lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
  return lcs_table


def _backtrack_norec(t, ref, can):
  """Read out LCS."""
  i = len(ref)
  j = len(can)
  lcs = []
  while i > 0 and j > 0:
    if ref[i - 1] == can[j - 1]:
      lcs.insert(0, i - 1)
      i -= 1
      j -= 1
    elif t[i][j - 1] > t[i - 1][j]:
      j -= 1
    else:
      i -= 1
  return lcs


def _summary_level_lcs(ref_sent, can_sent):
  """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.

  Args:
    ref_sent: list of tokenized reference sentences
    can_sent: list of tokenized candidate sentences

  Returns:
    summary level ROUGE score
  """
  if not ref_sent or not can_sent:
    return Score(precision=0, recall=0, fmeasure=0)

  m = sum(map(len, ref_sent))
  n = sum(map(len, can_sent))
  if not n or not m:
    return Score(precision=0, recall=0, fmeasure=0)

  # Get token counts to prevent double counting.
  token_cnts_r = collections.Counter()
  token_cnts_c = collections.Counter()
  for s in ref_sent:
    # s is a list of tokens
    token_cnts_r.update(s)
  for s in can_sent:
    token_cnts_c.update(s)

  hits = 0
  for r in ref_sent:
    lcs = _union_lcs(r, can_sent)
    # Prevent double-counting:
    # The paper describes just computing hits += len(_union_lcs()),
    # but the implementation prevents double counting. We also
    # implement this as in version 1.5.5.
    for t in lcs:
      if token_cnts_c[t] > 0 and token_cnts_r[t] > 0:
        hits += 1
        token_cnts_c[t] -= 1
        token_cnts_r[t] -= 1

  recall = hits / m
  precision = hits / n
  fmeasure = func_fmeasure(precision, recall)
  return Score(precision=precision, recall=recall, fmeasure=fmeasure)


def _union_lcs(ref, c_list):
  """Find union LCS between a ref sentence and list of candidate sentences.

  Args:
    ref: list of tokens
    c_list: list of candidate sentences, each a list of tokens

  Returns:
    List of tokens in ref representing union LCS.
  """
  lcs_list = [lcs_ind(ref, c) for c in c_list]
  return [ref[i] for i in _find_union(lcs_list)]


def _find_union(lcs_list):
  """Finds union LCS given a list of LCS."""
  return sorted(list(set().union(*lcs_list)))


def lcs_ind(ref, can):
  """Returns one of the longest lcs."""
  t = _lcs_table(ref, can)
  return _backtrack_norec(t, ref, can)


def _score_ngrams(target_ngrams, prediction_ngrams):
  """Compute n-gram based rouge scores.

  Args:
    target_ngrams: A Counter object mapping each n-gram to the number of
      occurrences for the target text.
    prediction_ngrams: A Counter object mapping each n-gram to the number of
      occurrences for the prediction text.

  Returns:
    A Score object containing computed scores.
  """

  intersection_ngrams_count = 0
  for ngram in six.iterkeys(target_ngrams):
    intersection_ngrams_count += min(target_ngrams[ngram],
                                     prediction_ngrams[ngram])
  target_ngrams_count = sum(target_ngrams.values())
  prediction_ngrams_count = sum(prediction_ngrams.values())

  precision = intersection_ngrams_count / max(prediction_ngrams_count, 1)
  recall = intersection_ngrams_count / max(target_ngrams_count, 1)
  fmeasure = func_fmeasure(precision, recall)

  return Score(precision=precision, recall=recall, fmeasure=fmeasure)
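A minimal sketch of how these helpers fit together for one target/prediction pair; a plain whitespace split stands in for the project's tokenizer, so this is illustrative only.

# Illustrative only: whitespace split instead of the real tokenizer.
target_tokens = "the cat sat on the mat".split()
prediction_tokens = "the cat lay on the mat".split()

rouge2 = _score_ngrams(_create_ngrams(target_tokens, 2),
                       _create_ngrams(prediction_tokens, 2))
rougeL = _score_lcs(target_tokens, prediction_tokens)
print(rouge2.fmeasure, rougeL.fmeasure)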
@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library containing Tokenizer definitions.

The RougeScorer class can be instantiated with the tokenizers defined here. New
tokenizers can be defined by creating a subclass of the Tokenizer abstract class
and overriding the tokenize() method.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import re

from nltk.stem import porter
import six


# Pre-compile regexes that are used often.
NON_ALPHANUM_PATTERN = r"[^a-z0-9]+"
NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN)
SPACES_PATTERN = r"\s+"
SPACES_RE = re.compile(SPACES_PATTERN)
VALID_TOKEN_PATTERN = r"^[a-z0-9]+$"
VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN)


def _tokenize(text, stemmer):
  """Tokenize input text into a list of tokens.

  This approach aims to replicate the approach taken by Chin-Yew Lin in
  the original ROUGE implementation.

  Args:
    text: A text blob to tokenize.
    stemmer: An optional stemmer.

  Returns:
    A list of string tokens extracted from input text.
  """

  # Convert everything to lowercase.
  text = text.lower()
  # Replace any non-alpha-numeric characters with spaces.
  text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text))

  tokens = SPACES_RE.split(text)
  if stemmer:
    # Only stem words that are longer than 3 characters.
    tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x
              for x in tokens]

  # One final check to drop any empty or invalid tokens.
  tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)]

  return tokens


class Tokenizer(abc.ABC):
  """Abstract base class for a tokenizer.

  Subclasses of Tokenizer must implement the tokenize() method.
  """

  @abc.abstractmethod
  def tokenize(self, text):
    raise NotImplementedError("Tokenizer must override tokenize() method")


class DefaultTokenizer(Tokenizer):
  """Default tokenizer which tokenizes on whitespace."""

  def __init__(self, use_stemmer=False):
    """Constructor for DefaultTokenizer.

    Args:
      use_stemmer: boolean, indicating whether Porter stemmer should be used to
        strip word suffixes to improve matching.
    """
    self._stemmer = porter.PorterStemmer() if use_stemmer else None

  def tokenize(self, text):
    return _tokenize(text, self._stemmer)
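A short usage sketch for the tokenizer defined above (requires nltk for the Porter stemmer):

tokenizer = DefaultTokenizer(use_stemmer=True)
tokens = tokenizer.tokenize("The PROBLEMS were solved, quickly!")
# Lowercased, punctuation stripped, and words longer than 3 characters stemmed.
print(tokens)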
@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for scoring and evaluation of text samples.

Aggregation functions use bootstrap resampling to compute confidence intervals
as per the original ROUGE perl implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
from typing import Dict

import numpy as np
import six
from six.moves import range


class Score(
    collections.namedtuple("Score", ["precision", "recall", "fmeasure"])):
  """Tuple containing precision, recall, and f-measure values."""


class BaseScorer(object, metaclass=abc.ABCMeta):
  """Base class for Scorer objects."""

  @abc.abstractmethod
  def score(self, target, prediction):
    """Calculates score between the target and prediction.

    Args:
      target: Text containing the target (ground truth) text.
      prediction: Text containing the predicted text.

    Returns:
      A dict mapping each score_type (string) to a Score object.
    """


class AggregateScore(
    collections.namedtuple("AggregateScore", ["low", "mid", "high"])):
  """Tuple containing confidence intervals for scores."""


class BootstrapAggregator(object):
  """Aggregates scores to provide confidence intervals.

  Sample usage:
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
    aggregator = BootstrapAggregator()
    aggregator.add_scores(scorer.score("one two three", "one two"))
    aggregator.add_scores(scorer.score("one two five six", "seven eight"))
    result = aggregator.aggregate()
    print(result)
    {'rougeL': AggregateScore(
        low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
        mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
        high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
     'rouge1': AggregateScore(
        low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
        mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
        high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
  """

  def __init__(self, confidence_interval=0.95, n_samples=1000):
    """Initializes a BootstrapAggregator object.

    Args:
      confidence_interval: Confidence interval to compute on the mean as a
        decimal.
      n_samples: Number of samples to use for bootstrap resampling.

    Raises:
      ValueError: If invalid argument is given.
    """

    if confidence_interval < 0 or confidence_interval > 1:
      raise ValueError("confidence_interval must be in range [0, 1]")
    if n_samples <= 0:
      raise ValueError("n_samples must be positive")

    self._n_samples = n_samples
    self._confidence_interval = confidence_interval
    self._scores = collections.defaultdict(list)

  def add_scores(self, scores):
    """Adds a sample for future aggregation.

    Args:
      scores: Dict mapping score_type strings to a namedtuple object/class
        representing a score.
    """

    for score_type, score in six.iteritems(scores):
      self._scores[score_type].append(score)

  def aggregate(self):
    """Aggregates scores previously added using add_scores.

    Returns:
      A dict mapping score_type to AggregateScore objects.
    """

    result = {}
    for score_type, scores in six.iteritems(self._scores):
      # Stack scores into a 2-d matrix of (sample, measure).
      score_matrix = np.vstack(tuple(scores))
      # Percentiles are returned as (interval, measure).
      percentiles = self._bootstrap_resample(score_matrix)
      # Extract the three intervals (low, mid, high).
      intervals = tuple(
          (scores[0].__class__(*percentiles[j, :]) for j in range(3)))
      result[score_type] = AggregateScore(
          low=intervals[0], mid=intervals[1], high=intervals[2])
    return result

  def _bootstrap_resample(self, matrix):
    """Performs bootstrap resampling on a matrix of scores.

    Args:
      matrix: A 2-d matrix of (sample, measure).

    Returns:
      A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
      mid (row 1) and high (row 2). Mid is always the mean, while low and high
      bounds are specified by self._confidence_interval (which defaults to 0.95
      meaning it will return the 2.5th and 97.5th percentiles for a 95%
      confidence interval on the mean).
    """

    # Matrix of (bootstrap sample, measure).
    sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
    for i in range(self._n_samples):
      sample_idx = np.random.choice(
          np.arange(matrix.shape[0]), size=matrix.shape[0])
      sample = matrix[sample_idx, :]
      sample_mean[i, :] = np.mean(sample, axis=0)

    # Take percentiles on the estimate of the mean using bootstrap samples.
    # Final result is a (bounds, measure) matrix.
    percentile_delta = (1 - self._confidence_interval) / 2
    q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
    return np.percentile(sample_mean, q, axis=0)


def func_fmeasure(precision, recall):
  """Computes f-measure given precision and recall values."""

  if precision + recall > 0:
    return 2 * precision * recall / (precision + recall)
  else:
    return 0.0
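A small usage sketch with hand-made Score tuples, so no RougeScorer is required; values are illustrative.

aggregator = BootstrapAggregator(confidence_interval=0.95, n_samples=1000)
for p, r in [(0.5, 0.4), (0.8, 0.7), (0.6, 0.9)]:
  aggregator.add_scores({"rouge1": Score(precision=p, recall=r,
                                         fmeasure=func_fmeasure(p, r))})
result = aggregator.aggregate()["rouge1"]
# result.mid is the central estimate; result.low / result.high bound the
# confidence interval obtained from bootstrap resampling of the added samples.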
@ -0,0 +1,92 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """

import argparse
import json
import re
import string
import sys
from collections import Counter


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]
                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {"exact_match": exact_match, "f1": f1}


if __name__ == "__main__":
    expected_version = "1.1"
    parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if dataset_json["version"] != expected_version:
            print(
                "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
                file=sys.stderr,
            )
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions)))
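The script can also be used programmatically; a minimal sketch with a toy SQuAD-v1.1-shaped dataset:

dataset = [{"paragraphs": [{"qas": [
    {"id": "q1", "answers": [{"text": "Denver Broncos"}]}
]}]}]
predictions = {"q1": "Broncos"}
# exact_match penalizes the partial answer, while f1 gives credit for token overlap.
print(evaluate(dataset, predictions))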
@ -0,0 +1,109 @@
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD metric. """

import datasets

from .evaluate import evaluate


_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
  title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
  author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
  booktitle={EMNLP},
  year={2016}
}
"""

_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
"""

_KWARGS_DESCRIPTION = """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
            }
            Note that answer_start values are not taken into account to compute the metric.
Returns:
    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
    'f1': The F-score of predicted tokens versus the gold answer
Examples:

    >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
    >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
    >>> squad_metric = datasets.load_metric("squad")
    >>> results = squad_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 100.0, 'f1': 100.0}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references):
        pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
        dataset = [
            {
                "paragraphs": [
                    {
                        "qas": [
                            {
                                "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
                                "id": ref["id"],
                            }
                            for ref in references
                        ]
                    }
                ]
            }
        ]
        score = evaluate(dataset=dataset, predictions=pred_dict)
        return score
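For reference, a sketch of the reshaping _compute performs: flat references are wrapped back into the article/paragraphs/qas layout that the original evaluate() script expects (toy values only).

references = [{"id": "q1", "answers": {"text": ["1976"], "answer_start": [97]}}]
dataset = [{"paragraphs": [{"qas": [
    {"id": ref["id"], "answers": [{"text": t} for t in ref["answers"]["text"]]}
    for ref in references
]}]}]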
@ -0,0 +1,330 @@
"""Official evaluation script for SQuAD version 2.0.

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import argparse
import collections
import json
import os
import re
import string
import sys

import numpy as np


OPTS = None


def parse_args():
    parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.")
    parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
    parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
    parser.add_argument(
        "--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)."
    )
    parser.add_argument(
        "--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer."
    )
    parser.add_argument(
        "--na-prob-thresh",
        "-t",
        type=float,
        default=1.0,
        help='Predict "" if no-answer probability exceeds this (default = 1.0).',
    )
    parser.add_argument(
        "--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory."
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()


def make_qid_to_has_ans(dataset):
    qid_to_has_ans = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
    return qid_to_has_ans


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold, a_pred):
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1(a_gold, a_pred):
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def get_raw_scores(dataset, preds):
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for p in article["paragraphs"]:
            for qa in p["qas"]:
                qid = qa["id"]
                if qid not in preds:
                    print(f"Missing prediction for {qid}")
                    continue
                a_pred = preds[qid]

                gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
                if not gold_answers:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = [""]
                    if a_pred != "":
                        exact_scores[qid] = 0
                        f1_scores[qid] = 0
                    else:
                        exact_scores[qid] = 1
                        f1_scores[qid] = 1
                else:
                    # Take max over all gold answers
                    exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                    f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores


def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores


def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores.values()) / total),
                ("f1", 100.0 * sum(f1_scores.values()) / total),
                ("total", total),
            ]
        )
    else:
        total = len(qid_list)
        return collections.OrderedDict(
            [
                ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
                ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
                ("total", total),
            ]
        )


def merge_eval(main_eval, new_eval, prefix):
    for k in new_eval:
        main_eval[f"{prefix}_{k}"] = new_eval[k]


def plot_pr_curve(precisions, recalls, out_image, title):
    plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
    plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()


def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p = 1.0
    cur_r = 0.0
    precisions = [1.0]
    recalls = [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(qid_list):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i + 1)
        cur_r = true_pos / float(num_true_pos)
        if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {"ap": 100.0 * avg_prec}


def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        return
    pr_exact = make_precision_recall_eval(
        exact_raw,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_exact.png"),
        title="Precision-Recall curve for Exact Match score",
    )
    pr_f1 = make_precision_recall_eval(
        f1_raw,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_f1.png"),
        title="Precision-Recall curve for F1 score",
    )
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = make_precision_recall_eval(
        oracle_scores,
        na_probs,
        num_true_pos,
        qid_to_has_ans,
        out_image=os.path.join(out_image_dir, "pr_oracle.png"),
        title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
    )
    merge_eval(main_eval, pr_exact, "pr_exact")
    merge_eval(main_eval, pr_f1, "pr_f1")
    merge_eval(main_eval, pr_oracle, "pr_oracle")


def histogram_na_prob(na_probs, qid_list, image_dir, name):
    if not qid_list:
        return
    x = [na_probs[k] for k in qid_list]
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel("Model probability of no-answer")
    plt.ylabel("Proportion of dataset")
    plt.title(f"Histogram of no-answer probability: {name}")
    plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
    plt.clf()


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh


def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval["best_exact"] = best_exact
    main_eval["best_exact_thresh"] = exact_thresh
    main_eval["best_f1"] = best_f1
    main_eval["best_f1_thresh"] = f1_thresh


def main():
    with open(OPTS.data_file) as f:
        dataset_json = json.load(f)
        dataset = dataset_json["data"]
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        na_probs = {k: 0.0 for k in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, "HasAns")
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, "NoAns")
    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
    if OPTS.out_file:
        with open(OPTS.out_file, "w") as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))


if __name__ == "__main__":
    OPTS = parse_args()
    if OPTS.out_image_dir:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    main()
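A small sketch of how the no-answer threshold rewrites raw scores: once the predicted no-answer probability exceeds the threshold, the prediction counts as an abstention, which is correct only for unanswerable questions (toy values).

exact_raw = {"q_has": 1, "q_noans": 0}
na_probs = {"q_has": 0.9, "q_noans": 0.8}
qid_to_has_ans = {"q_has": True, "q_noans": False}
print(apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, na_prob_thresh=0.5))
# -> {'q_has': 0.0, 'q_noans': 1.0}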
@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD v2 metric. """

import datasets

from .evaluate import (
    apply_no_ans_threshold,
    find_all_best_thresh,
    get_raw_scores,
    make_eval_dict,
    make_qid_to_has_ans,
    merge_eval,
)


_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
  title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
  author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
  booktitle={EMNLP},
  year={2016}
}
"""

_DESCRIPTION = """
This metric wraps the official scoring script for version 2 of the Stanford Question
Answering Dataset (SQuAD).

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.

SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions
written adversarially by crowdworkers to look similar to answerable ones.
To do well on SQuAD2.0, systems must not only answer questions when possible, but also
determine when no answer is supported by the paragraph and abstain from answering.
"""

_KWARGS_DESCRIPTION = """
Computes SQuAD v2 scores (F1 and EM).
Args:
    predictions: List of question-answer dictionaries to score, with the following elements:
        - the question-answer 'id' field as given in the references (see below)
        - the text of the answer
        - the probability that the question has no answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a list of Dict {'text': text of the answer as a string}
    no_answer_threshold: float
        Probability threshold to decide that a question has no answer.
Returns:
    'exact': Exact match (the normalized answer exactly matches the gold answer)
    'f1': The F-score of predicted tokens versus the gold answer
    'total': Number of scores considered
    'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer)
    'HasAns_f1': The F-score of predicted tokens versus the gold answer
    'HasAns_total': Number of scores considered
    'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer)
    'NoAns_f1': The F-score of predicted tokens versus the gold answer
    'NoAns_total': Number of scores considered
    'best_exact': Best exact match (with varying threshold)
    'best_exact_thresh': No-answer probability threshold associated to the best exact match
    'best_f1': Best F1 (with varying threshold)
    'best_f1_thresh': No-answer probability threshold associated to the best F1
Examples:

    >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}]
    >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
    >>> squad_v2_metric = datasets.load_metric("squad_v2")
    >>> results = squad_v2_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SquadV2Local(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {
                        "id": datasets.Value("string"),
                        "prediction_text": datasets.Value("string"),
                        "no_answer_probability": datasets.Value("float32"),
                    },
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {"text": datasets.Value("string"), "answer_start": datasets.Value("int32")}
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references, no_answer_threshold=1.0):
        # no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions)
        dataset = [{"paragraphs": [{"qas": references}]}]
        predictions = dict((p["id"], p["prediction_text"]) for p in predictions)

        # qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
        # has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
        # no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]

        exact_raw, f1_raw = get_raw_scores(dataset, predictions)
        out_eval = make_eval_dict(exact_raw, f1_raw)

        # exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
        # f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
        # out_eval = make_eval_dict(exact_thresh, f1_thresh)
        #
        # if has_ans_qids:
        #     has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        #     merge_eval(out_eval, has_ans_eval, "HasAns")
        # if no_ans_qids:
        #     no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        #     merge_eval(out_eval, no_ans_eval, "NoAns")
        # find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
        return dict(out_eval)
@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import SimpleQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput

def preprocess_sqaud_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_sqaud_abstractive_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_boolq_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    question_column, context_column, answer_column = 'question', 'passage', 'answer'
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [str(ans) for ans in answers]  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets

def preprocess_boolq_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets

def preprocess_narrativeqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [exp['summary']['text'] for exp in examples['document']]
    questions = [exp['text'] for exp in examples['question']]
    answers = [ans[0]['text'] for ans in examples['answers']]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = answers  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_narrativeqa_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_drop_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['passage']
    questions = examples['question']
    answers = examples['answers']
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_race_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_newsqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_ropes_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    backgrounds = examples["background"]
    situations = examples["situation"]
    answers = examples[answer_column]

    inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_openbookqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples['question_stem']
    all_options = examples['choices']
    answers = examples['answerKey']
    options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
    inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_social_iqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_dream_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
    questions = examples['question']
    all_options = examples['choice']
    answers = examples['answer']
    answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    targets = answers
    return inputs, targets
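A toy call showing the input/target pairs these preprocessors emit (column names follow the HuggingFace SQuAD schema; values are illustrative):

batch = {
    "question": ["Who wrote Hamlet?"],
    "context": ["Hamlet is a tragedy written by William Shakespeare."],
    "answers": [{"text": ["William Shakespeare"], "answer_start": [31]}],
}
inputs, targets = preprocess_sqaud_batch(batch, "question", "context", "answers")
# inputs[0] pairs the question with its context in the SimpleQAInput prompt format;
# targets[0] == "William Shakespeare" (the first gold answer string).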
@ -0,0 +1,21 @@
#ProQA
#[paper]https://arxiv.org/abs/2205.04040
class StructuralQAInput:
    @classmethod
    def qg_input(cls, context, question, options=None):
        source_text = f'[QUESTION] {question}. [CONTEXT] {context} the answer is: '
        return source_text


class SimpleQAInput:
    @classmethod
    def qg_input(cls, context, question, options=None):
        source_text = f'{question} \\n {context}'
        return source_text
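For clarity, the two builders produce the following prompt shapes for the same toy example:

q, c = "who wrote hamlet", "Hamlet was written by Shakespeare."
print(StructuralQAInput.qg_input(c, q))
# [QUESTION] who wrote hamlet. [CONTEXT] Hamlet was written by Shakespeare. the answer is: 
print(SimpleQAInput.qg_input(c, q))
# who wrote hamlet \n Hamlet was written by Shakespeare.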
@ -0,0 +1,192 @@
import sys
sys.path.append('../data_process')
from typing import List, Optional, Tuple
from QAInput import StructuralQAInput as QAInput
#from QAInput import QAInputNoPrompt as QAInput

def preprocess_sqaud_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_sqaud_abstractive_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_boolq_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    question_column, context_column, answer_column = 'question', 'passage', 'answer'
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [str(ans) for ans in answers]  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets

def preprocess_boolq_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
    # print(inputs, targets)
    return inputs, targets

def preprocess_narrativeqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [exp['summary']['text'] for exp in examples['document']]
    questions = [exp['text'] for exp in examples['question']]
    answers = [ans[0]['text'] for ans in examples['answers']]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = answers  # [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_narrativeqa_batch_pretrain(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_drop_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['passage']
    questions = examples['question']
    answers = examples['answers']
    inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_race_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_newsqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    contexts = examples[context_column]
    answers = examples[answer_column]
    inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
    # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets


def preprocess_ropes_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples[question_column]
    backgrounds = examples["background"]
    situations = examples["situation"]
    answers = examples[answer_column]

    inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
    targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
    return inputs, targets

def preprocess_openbookqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    questions = examples['question_stem']
    all_options = examples['choices']
    answers = examples['answerKey']
    options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
    inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_social_iqa_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = examples['article']
    questions = examples['question']
    all_options = examples['options']
    answers = examples['answer']
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    ans_map = {'A': 0, 'B': 1, 'C': 2}
    targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
    return inputs, targets

def preprocess_dream_batch(
    examples,
    question_column: str,
    context_column: str,
    answer_column: str,
) -> Tuple[List[str], List[str]]:
    contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
    questions = examples['question']
    all_options = examples['choice']
    answers = examples['answer']
    answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
    options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
    inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
    targets = answers
    return inputs, targets
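This module mirrors simple_processors.py above; only the QAInput import differs, so the same call yields structurally prompted inputs. A toy sketch:

inputs, targets = preprocess_sqaud_batch(
    {"question": ["Who wrote Hamlet?"],
     "context": ["Hamlet was written by Shakespeare."],
     "answers": [{"text": ["Shakespeare"], "answer_start": [22]}]},
    "question", "context", "answers")
# inputs[0] now carries StructuralQAInput's prompt markers rather than the plain
# question/context concatenation used by simple_processors.py.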
@ -0,0 +1,801 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metadeca import T5ForConditionalGeneration as PromptT5
|
||||
from metrics import compute_metrics
|
||||
from downstreamdeca.simple_processors import *
|
||||
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizers that require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
task2id = {}
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
|
||||
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
|
||||
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks supported by the prompt pool."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of supported QA task formats."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The number of prompt tokens allocated to each task and to each format."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the model input."},
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model,trainset):
|
||||
with torch.no_grad():
|
||||
device = 'cuda:0'
|
||||
search_upbound = len(trainset)//4
|
||||
query_idxs = [None]*30
|
||||
keys = model.encoder.domain_keys
|
||||
for idx,item in enumerate(trainset.select(range(search_upbound))):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
result = torch.matmul(query,keys.t())
|
||||
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
|
||||
key_sel = None
|
||||
for key_idx in result:
|
||||
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
|
||||
key_sel = key_idx
|
||||
break
|
||||
if key_sel is not None:
|
||||
if query_idxs[key_sel] is None:
|
||||
query_idxs[key_sel] = [idx]
|
||||
else:
|
||||
query_idxs[key_sel].append(idx)
|
||||
total_idxs = []
|
||||
for item in query_idxs:
|
||||
try:
|
||||
total_idxs.extend(item[:3])
|
||||
except TypeError:  # this key received no sampled examples; fall back to random indices
|
||||
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
|
||||
total_idxs = list(set(total_idxs))
|
||||
total_idxs = random.sample(total_idxs,50)
|
||||
sub_set = trainset.select(total_idxs)
|
||||
|
||||
features = []
|
||||
for idx,item in enumerate(sub_set):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
features.append(query.detach().cpu().numpy())
|
||||
|
||||
|
||||
|
||||
return sub_set,features
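# Summary of the selection above: each training example is routed to its best-matching
# prompt key (top-5 by query/key dot product), at most 3 examples are kept per key,
# keys that received nothing fall back to random examples, and 50 of the collected
# indices are kept as the replay subset together with their query features.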
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
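# These are local metric scripts, e.g. load_metric(dataset_name_to_metric['narrativeqa'])
# would load the local ROUGE implementation. (Note that the evaluation loop below
# hard-codes the SQuAD-v1 script instead of consulting this mapping.)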
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
#reload format specific task-prompt for newly involved task
|
||||
# format prompts ### task prompts
|
||||
|
||||
data_args.reload_from_trained_prompt = False#@
|
||||
data_args.load_from_format_task_id = False#@
|
||||
### note: these overrides are applied before pre-training
|
||||
|
||||
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
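# e.g. with the default prompt_number=100 and the three formats above, task_start_id is 300,
# so a task with task2id 1 (squad) would own rows 400..499 of the prompt embedding table.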
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the format mappings do not match'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
|
||||
device = torch.device("cuda")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
|
||||
|
||||
if True:
|
||||
cur_dataset = ds_name
|
||||
train_dataloaders[ds_name] = load_from_disk("./oursdeca/{}-train.hf".format(cur_dataset)).select(range(200))
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdeca/{}-evalex.hf".format(cur_dataset)).select(range(200)))
|
||||
|
||||
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(ds_name)),load_from_disk("./oursdeca/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
|
||||
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
|
||||
need_to_do_dss = []
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
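# Continual-learning loop: for each task in task_sequence, fuse the replay memory with the
# new task's 200-example training split, train for 5 epochs, store 50 diverse replay samples
# and their query features, register them via model.encoder.add_negs, and then evaluate on
# every task seen so far (plus cnn_dailymail, multinli.in.out and zre after the final task).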
|
||||
for cur_dataset in task_sequence:
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
|
||||
else:
|
||||
fused = train_dataloaders[cur_dataset]
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
if training_args.local_rank<=0:
|
||||
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
|
||||
|
||||
|
||||
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
max_length, num_beams, ignore_keys_for_eval = None, None, None
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# (`kwargs` is never defined in this script, so the leftover language bookkeeping is dropped)
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@@ -0,0 +1,53 @@
|
|||
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 diana.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
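# Example invocation (the shell file's name is not shown in this diff, so "train_diana.sh" is assumed):
#   bash train_diana.sh ./outputs 8 16
# launches 4-process distributed training of t5-base (per-device train batch 8, eval batch 16)
# and writes checkpoints plus the log under ./outputs/lifelong/try.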
|
|
@@ -0,0 +1,6 @@
|
|||
after_train_ squad _test_ squad
|
||||
OrderedDict([('eval_em', 88.0), ('eval_nf1', 91.79166666666667), ('eval_nem', 89.5), ('eval_samples', 200)])
|
||||
after_train_ wikisql _test_ squad
|
||||
OrderedDict([('eval_em', 83.0), ('eval_nf1', 87.73095238095239), ('eval_nem', 84.5), ('eval_samples', 200)])
|
||||
after_train_ wikisql _test_ wikisql
|
||||
OrderedDict([('eval_lfem', 30.5), ('eval_em', 17.5), ('eval_nf1', 84.06011500628662), ('eval_nem', 20.0), ('eval_samples', 200)])
|
|
@@ -0,0 +1,775 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metadecanometa import T5ForConditionalGeneration as PromptT5
|
||||
from metrics import compute_metrics
|
||||
from downstreamdeca.simple_processors import *
|
||||
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizers that require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
task2id = {}
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
|
||||
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
|
||||
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
default=30,
metadata={"help": "The maximum number of tasks supported by the prompt pool."},
)
qa_task_type_num: Optional[int] = field(
default=4,
metadata={"help": "The number of supported QA task formats."},
)
prompt_number: Optional[int] = field(
default=100,
metadata={"help": "The number of prompt tokens allocated to each task and to each format."},
)
add_task_prompt: Optional[bool] = field(
default=False,
metadata={"help": "Whether to prepend task-specific prompt tokens to the model input."},
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model,trainset):
|
||||
with torch.no_grad():
|
||||
device = 'cuda:0'
|
||||
|
||||
sub_set = trainset.select(range(50))
|
||||
|
||||
features = []
|
||||
for idx,item in enumerate(sub_set):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
features.append(query.detach().cpu().numpy())
|
||||
|
||||
|
||||
|
||||
return sub_set,features
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
# Reload the format-specific task prompt for a newly added task:
|
||||
# the new task's prompt block is initialized from its format's prompt block.
|
||||
|
||||
data_args.reload_from_trained_prompt = False  # forced off here
|
||||
data_args.load_from_format_task_id = False  # forced off here
|
||||
# Both reload paths stay disabled until pre-trained prompts are available.
|
||||
|
||||
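# Prompt-embedding layout assumed by the indexing below (inferred from how the ids are
# computed): the first prompt_number * n_formats rows hold the format-level prompts,
# and each task's block starts at task_start_id + task2id[name] * prompt_number.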
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the stored formats do not match the current format2id mapping'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
|
||||
device = torch.device("cpu" if training_args.no_cuda else "cuda")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
|
||||
|
||||
if True:
|
||||
cur_dataset = ds_name
|
||||
train_dataloaders[ds_name] = load_from_disk("./oursdecanometa/{}-train.hf".format(cur_dataset)).select(range(200))
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(cur_dataset)).select(range(200)))
|
||||
|
||||
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
|
||||
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
|
||||
need_to_do_dss = []
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
|
||||
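# Continual-learning loop: train on each task in task_sequence in order, fusing the
# replay buffer collected from earlier tasks into the current training set, then
# evaluate on every task seen so far (plus the held-out tasks after the last one).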
for cur_dataset in task_sequence:
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
|
||||
else:
|
||||
fused = train_dataloaders[cur_dataset]
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
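# Only rank 0 builds the replay buffer: save_load_diverse_sample() picks 50 diverse
# examples from the current task and returns their encoder query features, which are
# later registered via model.encoder.add_negs().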
if training_args.local_rank<=0:
|
||||
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
|
||||
|
||||
|
||||
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
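# Note: the generation settings computed above are overridden with None here, so
# trainer.evaluate() falls back to the model/config defaults.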
max_length, num_beams, ignore_keys_for_eval = None, None, None
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Leftover from the HF translation example; `kwargs` is initialized here so the
# branch below cannot raise a NameError when source/target languages are set.
kwargs = {}
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
|
||||
if len(languages) > 0:
|
||||
    kwargs["language"] = languages
|
||||
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,53 @@
|
|||
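# Usage: bash <this script> SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE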
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
|
|
@ -0,0 +1,801 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import copy,random
|
||||
import sys
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from sklearn.cluster import KMeans
|
||||
# from data import QAData
|
||||
sys.path.append("../")
|
||||
from models.metadecanotask import T5ForConditionalGeneration as PromptT5
|
||||
from metrics import compute_metrics
|
||||
from downstreamdeca.simple_processors import *
|
||||
from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets
|
||||
import os
|
||||
|
||||
from functools import partial
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from pathlib import Path
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes.
|
||||
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
|
||||
format2id = {'extractive':0,'abstractive':1,'multichoice':2}
|
||||
format2dataset = {
|
||||
'extractive':['squad','extractive','newsqa','quoref'], #,'ropes'
|
||||
'abstractive':['narrativeqa','abstractive','nqopen','drop'],
|
||||
'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream']
|
||||
}
|
||||
seed_datasets = ['race','narrativeqa','squad']
|
||||
dataset2format= {}
|
||||
task2id = {}
|
||||
for k,vs in format2dataset.items():
|
||||
for v in vs:
|
||||
dataset2format[v] = k
|
||||
|
||||
|
||||
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
|
||||
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
fix_t5: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to fix the main parameters of t5"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
source_lang: str = field(default=None, metadata={"help": "Source language id for translation."})
|
||||
target_lang: str = field(default=None, metadata={"help": "Target language id for translation."})
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
after_train: Optional[str] = field(default=None)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
||||
},
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
)
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
},
|
||||
)
|
||||
do_lowercase: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to process input into lowercase'
|
||||
}
|
||||
)
|
||||
append_another_bos: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
'help':'Whether to add another bos'
|
||||
}
|
||||
)
|
||||
context_column: Optional[str] = field(
|
||||
default="context",
|
||||
metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."},
|
||||
)
|
||||
question_column: Optional[str] = field(
|
||||
default="question",
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
answer_column: Optional[str] = field(
|
||||
default="answers",
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
max_task_num: Optional[int] = field(
|
||||
default=30,
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
qa_task_type_num: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."},
|
||||
)
|
||||
prompt_number: Optional[int] = field(
|
||||
default=100,
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
add_task_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."},
|
||||
)
|
||||
reload_from_trained_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from trained prompt"},
|
||||
)
|
||||
trained_prompt_path:Optional[str] = field(
|
||||
default = None,
|
||||
metadata={
|
||||
'help':'the path storing trained prompt embedding'
|
||||
}
|
||||
)
|
||||
load_from_format_task_id: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "whether to reload prompt from format-corresponding task prompt"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
# elif self.source_lang is None or self.target_lang is None:
|
||||
# raise ValueError("Need to specify the source language and the target language.")
|
||||
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
# assert extension == "json", "`train_file` should be a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
# assert extension == "json", "`validation_file` should be a json file."
|
||||
if self.val_max_target_length is None:
|
||||
self.val_max_target_length = self.max_target_length
|
||||
|
||||
|
||||
# +
|
||||
def main():
|
||||
|
||||
|
||||
|
||||
def preprocess_validation_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
|
||||
def save_prompt_embedding(model,path):
|
||||
prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight']
|
||||
save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id}
|
||||
prompt_path = os.path.join(path,'prompt_embedding_info')
|
||||
torch.save(save_prompt_info,prompt_path)
|
||||
logger.info(f'Saving prompt embedding information to {prompt_path}')
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column)
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
if padding == "max_length" and data_args.ignore_pad_token_for_loss:
|
||||
labels["input_ids"] = [
|
||||
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
|
||||
]
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
def save_load_diverse_sample(model,trainset):
|
||||
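# Selection strategy (as implemented below): scan the first quarter of the training set,
# route each example to its best-matching domain keys (top-5 by dot product with the
# encoder query vector), keep at most 3 examples per key, fall back to random samples
# for keys with no assignment, then downsample to 50 examples and cache their query
# features. Note: the device is hard-coded to 'cuda:0'.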
with torch.no_grad():
|
||||
device = 'cuda:0'
|
||||
search_upbound = len(trainset)//4
|
||||
query_idxs = [None]*30
|
||||
keys = model.encoder.domain_keys
|
||||
for idx,item in enumerate(trainset.select(range(search_upbound))):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
result = torch.matmul(query,keys.t())
|
||||
result = torch.topk(result,5).indices[0].cpu().numpy().tolist()
|
||||
key_sel = None
|
||||
for key_idx in result:
|
||||
if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3:
|
||||
key_sel = key_idx
|
||||
break
|
||||
if key_sel is not None:
|
||||
if query_idxs[key_sel] is None:
|
||||
query_idxs[key_sel] = [idx]
|
||||
else:
|
||||
query_idxs[key_sel].append(idx)
|
||||
total_idxs = []
|
||||
for item in query_idxs:
|
||||
try:
|
||||
total_idxs.extend(item[:3])
|
||||
except:
|
||||
total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3))
|
||||
total_idxs = list(set(total_idxs))
|
||||
total_idxs = random.sample(total_idxs,50)
|
||||
sub_set = trainset.select(total_idxs)
|
||||
|
||||
features = []
|
||||
for idx,item in enumerate(sub_set):
|
||||
query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device),
|
||||
attention_mask=torch.tensor([item['attention_mask']]).long().to(device),
|
||||
return_dict=True)
|
||||
features.append(query.detach().cpu().numpy())
|
||||
|
||||
|
||||
|
||||
return sub_set,features
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def dataset_name_to_func(dataset_name):
|
||||
mapping = {
|
||||
'squad': preprocess_sqaud_batch,
|
||||
'squad_v2': preprocess_sqaud_batch,
|
||||
'boolq': preprocess_boolq_batch,
|
||||
'narrativeqa': preprocess_narrativeqa_batch,
|
||||
'race': preprocess_race_batch,
|
||||
'newsqa': preprocess_newsqa_batch,
|
||||
'quoref': preprocess_sqaud_batch,
|
||||
'ropes': preprocess_ropes_batch,
|
||||
'drop': preprocess_drop_batch,
|
||||
'nqopen': preprocess_sqaud_abstractive_batch,
|
||||
# 'multirc': preprocess_boolq_batch,
|
||||
'boolq_np': preprocess_boolq_batch,
|
||||
'openbookqa': preprocess_openbookqa_batch,
|
||||
'mctest': preprocess_race_batch,
|
||||
'social_iqa': preprocess_social_iqa_batch,
|
||||
'dream': preprocess_dream_batch,
|
||||
}
|
||||
return mapping[dataset_name]
|
||||
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dataset_name_to_metric = {
|
||||
'squad': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'squad_v2': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'newsqa': 'metric/squad_v2_local/squad_v2_local.py',
|
||||
'boolq': 'metric/accuracy.py',
|
||||
'narrativeqa': 'metric/rouge_local/rouge_metric.py',
|
||||
'race': 'metric/accuracy.py',
|
||||
'quoref': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'ropes': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'drop': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'nqopen': 'metric/squad_v1_local/squad_v1_local.py',
|
||||
'boolq_np': 'metric/accuracy.py',
|
||||
'openbookqa': 'metric/accuracy.py',
|
||||
'mctest': 'metric/accuracy.py',
|
||||
'social_iqa': 'metric/accuracy.py',
|
||||
'dream': 'metric/accuracy.py',
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
if data_args.source_prefix is None and model_args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
"t5-large",
|
||||
"t5-3b",
|
||||
"t5-11b",
|
||||
]:
|
||||
logger.warning(
|
||||
"You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
|
||||
"`--source_prefix 'translate English to German: ' `"
|
||||
)
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
|
||||
# '[OPTIONS]'])
|
||||
tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
|
||||
special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
|
||||
'[OPTIONS]']}
|
||||
|
||||
tokenizer.add_tokens(tokens_to_add)
|
||||
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
added_tokens = tokenizer.get_added_vocab()
|
||||
logger.info('Added tokens: {}'.format(added_tokens))
|
||||
|
||||
model = PromptT5.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
# task_num = data_args.max_task_num,
|
||||
# prompt_num = data_args.prompt_number,
|
||||
# format_num = data_args.qa_task_type_num,
|
||||
# add_task_prompt = False
|
||||
)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
|
||||
# Reload the format-specific task prompt for a newly added task:
|
||||
# the new task's prompt block is initialized from its format's prompt block.
|
||||
|
||||
data_args.reload_from_trained_prompt = False  # forced off here
|
||||
data_args.load_from_format_task_id = False  # forced off here
|
||||
# Both reload paths stay disabled until pre-trained prompts are available.
|
||||
|
||||
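# Prompt-embedding layout assumed by the indexing below: format-level prompts occupy
# the first prompt_number * n_formats rows, followed by one prompt_number-sized block
# per task.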
if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt:
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:]
|
||||
logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}')
|
||||
# print(dataset2format[data_args.dataset_name])
|
||||
# print(data_args.dataset_name)
|
||||
elif data_args.reload_from_trained_prompt:
|
||||
assert data_args.trained_prompt_path,'Must specify the path of stored prompt'
|
||||
prompt_info = torch.load(data_args.trained_prompt_path)
|
||||
assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}'
|
||||
assert prompt_info['format2id'].keys()==format2id.keys(),'the stored formats do not match the current format2id mapping'
|
||||
task_start_id = data_args.prompt_number * len(format2dataset.keys())
|
||||
|
||||
task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number
|
||||
logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number))
|
||||
# assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:]
|
||||
format_id = format2id[dataset2format[data_args.dataset_name]]
|
||||
model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :]
|
||||
logger.info(
|
||||
f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}')
|
||||
|
||||
# Set decoder_start_token_id
|
||||
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
|
||||
if isinstance(tokenizer, MBartTokenizer):
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
|
||||
else:
|
||||
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
|
||||
if training_args.local_rank == -1 or training_args.no_cuda:
|
||||
device = torch.device("cpu" if training_args.no_cuda else "cuda")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Temporarily set max_target_length for training.
|
||||
max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
logger.warning(
|
||||
"label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
|
||||
f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
|
||||
)
|
||||
|
||||
|
||||
question_column = data_args.question_column
|
||||
context_column = data_args.context_column
|
||||
answer_column = data_args.answer_column
|
||||
# import random
|
||||
if data_args.max_source_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length)
|
||||
# Data collator
|
||||
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
#start
|
||||
train_dataloaders = {}
|
||||
eval_dataloaders = {}
|
||||
replay_dataloaders = {}
|
||||
all_replay = None
|
||||
|
||||
for ds_name in ["squad","wikisql","sst","srl","woz.en"]:
|
||||
|
||||
if True:
|
||||
cur_dataset = ds_name
|
||||
train_dataloaders[ds_name] = load_from_disk("./oursdecanotask/{}-train.hf".format(cur_dataset)).select(range(200))
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(cur_dataset)).select(range(200)))
|
||||
|
||||
for ds_name in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(ds_name)))
|
||||
|
||||
pre_tasks = []
|
||||
pre_general = []
|
||||
pre_test = []
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
task_sequence = ["squad","wikisql","sst","srl","woz.en"]
|
||||
# task_sequence = ["woz.en","srl","sst","wikisql","squad"]
|
||||
need_to_do_dss = []
|
||||
fileout = open("diana_log.txt",'w')
|
||||
all_replay = None
|
||||
all_features = []
|
||||
all_ids = []
|
||||
cluster_num=0
|
||||
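# Continual-learning loop over task_sequence: train on each task (fused with the
# replay buffer from earlier tasks), rebuild the replay buffer, then evaluate on all
# tasks seen so far.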
for cur_dataset in task_sequence:
|
||||
cluster_num+=5
|
||||
pre_tasks.append(cur_dataset)
|
||||
if cur_dataset==task_sequence[-1]:
|
||||
pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"])
|
||||
data_args.dataset_name = cur_dataset
|
||||
logger.info("current_dataset:"+cur_dataset)
|
||||
|
||||
|
||||
training_args.do_train = True
|
||||
training_args.to_eval = False
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
if all_replay is not None:
|
||||
fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]])
|
||||
else:
|
||||
fused = train_dataloaders[cur_dataset]
|
||||
training_args.num_train_epochs = 5
|
||||
model.encoder.reset_train_count()
|
||||
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=fused,
|
||||
eval_dataset=None,
|
||||
eval_examples=None,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
if training_args.local_rank<=0:
|
||||
save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset])
|
||||
|
||||
if all_replay is None:
|
||||
all_replay = save_set
|
||||
else:
|
||||
all_replay = datasets.concatenate_datasets([all_replay,save_set])
|
||||
|
||||
if all_features==[]:
|
||||
all_features=features
|
||||
else:
|
||||
all_features.extend(features)
|
||||
np.save("./all_features.npy",np.array(all_features))
|
||||
|
||||
all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
|
||||
|
||||
|
||||
if training_args.local_rank!=-1:
|
||||
torch.distributed.barrier()
|
||||
all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset))
|
||||
all_ids.extend([task2id[cur_dataset]]*50)
|
||||
all_features=np.load("./all_features.npy").tolist()
|
||||
|
||||
model.encoder.add_negs(all_ids,all_features)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for pre_dataset in pre_tasks:
|
||||
|
||||
|
||||
data_args.dataset_name = pre_dataset
|
||||
|
||||
metric = load_metric("metric/squad_v1_local/squad_v1_local.py")
|
||||
eval_dataset,eval_examples = eval_dataloaders[pre_dataset]
|
||||
trainer = QuestionAnsweringTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=None,
|
||||
eval_dataset=eval_dataset,
|
||||
eval_examples=eval_examples,
|
||||
answer_column_name=answer_column,
|
||||
dataset_name=data_args.dataset_name,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("*** Evaluate:{} ***".format(data_args.dataset_name))
|
||||
|
||||
max_length, num_beams, ignore_keys_for_eval = None, None, None
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
if training_args.local_rank<=0:
|
||||
try:
|
||||
print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout)
|
||||
print(metrics,file=fileout)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Leftover from the HF translation example; `kwargs` is initialized here so the
# branch below cannot raise a NameError when source/target languages are set.
kwargs = {}
languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None]
|
||||
if len(languages) > 0:
|
||||
    kwargs["language"] = languages
|
||||
return None
|
||||
|
||||
|
||||
# -
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,53 @@
|
|||
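# Usage: bash <this script> SAVING_PATH TRAIN_BATCH_SIZE EVAL_BATCH_SIZE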
function fulldata_hfdata() {
|
||||
SAVING_PATH=$1
|
||||
PRETRAIN_MODEL_PATH=$2
|
||||
TRAIN_BATCH_SIZE=$3
|
||||
EVAL_BATCH_SIZE=$4
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
mkdir -p ${SAVING_PATH}
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \
|
||||
--model_name_or_path ${PRETRAIN_MODEL_PATH} \
|
||||
--output_dir ${SAVING_PATH} \
|
||||
--dataset_name squad \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--per_device_train_batch_size ${TRAIN_BATCH_SIZE} \
|
||||
--per_device_eval_batch_size ${EVAL_BATCH_SIZE} \
|
||||
--overwrite_output_dir \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 5 \
|
||||
--warmup_ratio 0.1 \
|
||||
--logging_steps 5 \
|
||||
--learning_rate 1e-4 \
|
||||
--predict_with_generate \
|
||||
--num_beams 4 \
|
||||
--save_strategy no \
|
||||
--evaluation_strategy no \
|
||||
--weight_decay 1e-2 \
|
||||
--max_source_length 512 \
|
||||
--label_smoothing_factor 0.1 \
|
||||
--do_lowercase True \
|
||||
--load_best_model_at_end True \
|
||||
--greater_is_better True \
|
||||
--save_total_limit 10 \
|
||||
--ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log
|
||||
}
|
||||
|
||||
SAVING_PATH=$1
|
||||
TRAIN_BATCH_SIZE=$2
|
||||
EVAL_BATCH_SIZE=$3
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
MODE=try
|
||||
|
||||
|
||||
SAVING_PATH=${SAVING_PATH}/lifelong/${MODE}
|
||||
fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE}
|
|
@ -0,0 +1,169 @@
|
|||
import itertools
|
||||
import json
|
||||
import os
|
||||
import csv
|
||||
import errno
|
||||
import random
|
||||
from random import shuffle
|
||||
from typing import List
|
||||
import codecs
|
||||
import nltk
|
||||
import glob
|
||||
import xml.etree.ElementTree as ET
|
||||
from datasets import load_dataset
|
||||
|
||||
#nltk.download("stopwords")
|
||||
from nltk.corpus import stopwords
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
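# Note: these two preprocess_* helpers assume module-level `tokenizer` and `padding`
# objects that are not defined in this file; main() below only uses prep()/prep_test().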
def preprocess_function(examples):
|
||||
preprocess_fn = preprocess_all  # dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "input","output")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_function_test(examples):
|
||||
preprocess_fn = preprocess_all  # dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "input","output")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)
|
||||
model_inputs["example_id"] = []
|
||||
model_inputs["id"] = []
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(i)
|
||||
model_inputs["id"].append(i)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
|
||||
return model_inputs
|
||||
|
||||
def add_id(example,index):
|
||||
example.update({'id':index})
|
||||
return example
|
||||
|
||||
def prep(raw_ds,fname):
|
||||
ds = []
|
||||
dss = map(lambda x: x["paragraphs"], raw_ds["data"])
|
||||
for dd in dss:
|
||||
ds.extend(dd)
|
||||
print(len(ds))
|
||||
print(len(raw_ds["data"]))
|
||||
|
||||
examples = {"context":[],"question":[],"answer":[]}
|
||||
fout = open("{}-train.jsonl".format(fname),'w')
|
||||
for d in ds:
|
||||
|
||||
context = d["context"]
|
||||
#TOKENIZER.encode(d["context"])
|
||||
for qa in d["qas"]:
|
||||
question = qa["question"]#TOKENIZER.encode(qa["question"])
|
||||
|
||||
raw_answers = qa["answers"]
|
||||
if len(raw_answers) == 0:
|
||||
assert qa["is_impossible"]
|
||||
raw_answers.append({"text": ""})
|
||||
|
||||
answers = []
|
||||
|
||||
for i, raw_answer in enumerate(raw_answers):
|
||||
answers.append(raw_answer["text"])
|
||||
jsonline = json.dumps({"question":question,"context":context,"answer":answers})
|
||||
print(jsonline,file=fout)
|
||||
fout.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def prep_test(raw_ds,use_answers,fname):
|
||||
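# Flattens the SQuAD-format test split into a "<name>-eval.jsonl" file; the
# use_answers argument is ignored (reset to None below), and wikisql/woz.en/
# multinli.in.out examples are re-sorted by their original question id.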
ds = []
|
||||
dss = map(lambda x: x["paragraphs"], raw_ds["data"])
|
||||
for dd in dss:
|
||||
ds.extend(dd)
|
||||
print(len(ds))
|
||||
print(len(raw_ds["data"]))
|
||||
fout = open("{}-eval.jsonl".format(fname),'w')
|
||||
idx = 0
|
||||
f_answers = []
|
||||
use_answers = None
|
||||
all_json_lines = []
|
||||
for d in ds:
|
||||
|
||||
context = d["context"]
|
||||
#TOKENIZER.encode(d["context"])
|
||||
for qa in d["qas"]:
|
||||
question = qa["question"]#TOKENIZER.encode(qa["question"])
|
||||
raw_answers = qa["answers"]
|
||||
f_answers.extend([_["text"] for _ in qa["answers"]])
|
||||
if True:
|
||||
if len(raw_answers) == 0:
|
||||
assert qa["is_impossible"]
|
||||
raw_answers.append({"text": ""})
|
||||
|
||||
answers = []
|
||||
|
||||
for i, raw_answer in enumerate(raw_answers):
|
||||
answers.append(raw_answer["text"])
|
||||
all_json_lines.append({"question":question,"context":context,"answer":answers,"preid":qa["id"]})
|
||||
if fname in ["wikisql","woz.en","multinli.in.out"]:
|
||||
all_json_lines.sort(key=lambda x: x["preid"])
|
||||
|
||||
for item in all_json_lines:
|
||||
jsonline = json.dumps(item)
|
||||
print(jsonline,file=fout)
|
||||
|
||||
fout.close()
|
||||
|
||||
|
||||
from collections import Counter
|
||||
import json
|
||||
|
||||
|
||||
def main():
|
||||
for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]:
|
||||
print("current_dataset:",item)
|
||||
print("Loading......")
|
||||
|
||||
testfile = open(item+"_to_squad-test-v2.0.json",'r')
|
||||
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
trainfile = open(item+"_to_squad-train-v2.0.json",'r')
|
||||
ftrain = json.load(trainfile)
|
||||
ftest = json.load(testfile)
|
||||
faddition = None
|
||||
if item in ["woz.en","wikisql"]:
|
||||
addition_answers = open(item+"_answers.json",'r')
|
||||
faddition = json.load(addition_answers)
|
||||
if item=="woz.en":
|
||||
faddition = [[_[1]] for _ in faddition]
|
||||
else:
|
||||
faddition = [[_["answer"]] for _ in faddition]
|
||||
else:
|
||||
faddition = None
|
||||
print("Loaded.")
|
||||
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
prep(ftrain,item)
|
||||
|
||||
prep_test(ftest,faddition,item)
|
||||
|
||||
|
||||
|
||||
|
||||
main()
|
|
@ -0,0 +1,557 @@
|
|||
from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import nltk
|
||||
import datasets
|
||||
import re
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
from pathlib import Path
|
||||
import nltk
|
||||
# nltk.download('punkt')
|
||||
from transformers.trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalLoopOutput,
|
||||
EvalPrediction,
|
||||
HPSearchBackend,
|
||||
HubStrategy,
|
||||
IntervalStrategy,
|
||||
PredictionOutput,
|
||||
ShardedDDPOption,
|
||||
TrainerMemoryTracker,
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
denumpify_detensorize,
|
||||
get_last_checkpoint,
|
||||
number_of_arguments,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
import warnings
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
IterableDatasetShard,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
find_batch_size,
|
||||
get_parameter_names,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
nested_truncate,
|
||||
nested_xla_mesh_reduce,
|
||||
reissue_pt_warnings,
|
||||
)
|
||||
from transformers.file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_full_repo_name,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_in_notebook,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_tpu_available,
|
||||
)
|
||||
TRAINING_ARGS_NAME = "training_args.bin"
|
||||
TRAINER_STATE_NAME = "trainer_state.json"
|
||||
OPTIMIZER_NAME = "optimizer.pt"
|
||||
SCHEDULER_NAME = "scheduler.pt"
|
||||
SCALER_NAME = "scaler.pt"
|
||||
from transformers.trainer_utils import PredictionOutput,EvalLoopOutput
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.metrics as met
|
||||
|
||||
def fix_buggy_characters(str):
|
||||
return re.sub("[{}^\\\\`\u2047<]", " ", str)
|
||||
def replace_punctuation(str):
|
||||
return str.replace("\"", "").replace("'", "")
|
||||
def score_string_similarity(str1, str2):
|
||||
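# Similarity scale: 3.0 for an exact match, 2.0 for a match after punctuation/character
# normalization, otherwise the token-overlap ratio in [0, 1].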
if str1 == str2:
|
||||
return 3.0 # Better than perfect token match
|
||||
str1 = fix_buggy_characters(replace_punctuation(str1))
|
||||
str2 = fix_buggy_characters(replace_punctuation(str2))
|
||||
if str1 == str2:
|
||||
return 2.0
|
||||
if " " in str1 or " " in str2:
|
||||
str1_split = str1.split(" ")
|
||||
str2_split = str2.split(" ")
|
||||
overlap = list(set(str1_split) & set(str2_split))
|
||||
return len(overlap) / max(len(str1_split), len(str2_split))
|
||||
else:
|
||||
if str1 == str2:
|
||||
return 1.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
class QuestionAnsweringTrainer(Seq2SeqTrainer):
|
||||
def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.eval_examples = eval_examples
|
||||
self.answer_column_name = answer_column_name
|
||||
self.dataset_name = dataset_name
|
||||
self.tokenizer = tokenizer
|
||||
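# Datasets whose name contains "cnn" (summarization-style) use the narrative-QA
# post-processing; all other datasets are post-processed SQuAD-style.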
if "cnn" in dataset_name:
|
||||
self.post_process_function = self._post_process_narrative_qa
|
||||
else:
|
||||
self.post_process_function = self._post_process_squad
|
||||
|
||||
|
||||
def _post_process_squad(
|
||||
self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
|
||||
,version_2_with_negative=False):
|
||||
# Decode the predicted tokens.
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
|
||||
predictions = {}
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(examples):
|
||||
# This is the index of the feature associated to the current example.
|
||||
            try:
                feature_index = feature_per_example[example_index]
            except KeyError:
                # Surface the offending example instead of failing on a bare assert.
                print(feature_per_example.keys())
                print(example_index)
                raise
|
||||
predictions[example["id"]] = decoded_preds[feature_index]
|
||||
|
||||
# Format the result to the format the metric expects.
|
||||
if version_2_with_negative:
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
def _post_process_squad_v2(
|
||||
self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"):
|
||||
# Decode the predicted tokens.
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
|
||||
predictions = {}
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(examples):
|
||||
# This is the index of the feature associated to the current example.
|
||||
feature_index = feature_per_example[example_index]
|
||||
predictions[example["id"]] = decoded_preds[feature_index]
|
||||
|
||||
# Format the result to the format the metric expects.
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
|
||||
|
||||
@classmethod
|
||||
def postprocess_text(cls,preds, labels):
|
||||
preds = [pred.strip() for pred in preds]
|
||||
labels = [label.strip() for label in labels]
|
||||
# rougeLSum expects newline after each sentence
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
|
||||
return preds, labels
|
||||
|
||||
def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
outputs = []
|
||||
# for pred in preds:
|
||||
# if 'True' == pred:
|
||||
# outputs.append(1)
|
||||
# elif 'False' == pred:
|
||||
# outputs.append(0)
|
||||
# else:
|
||||
# outputs.append(0)
|
||||
# references = [1 if ex['answer'] else 0 for ex in examples]
|
||||
|
||||
references = []
|
||||
for pred, ref in zip(preds, examples):
|
||||
ans = ref['answer']
|
||||
references.append(1 if ans else 0)
|
||||
if pred.lower() in ['true', 'false']:
|
||||
outputs.append(1 if pred.lower() == 'true' else 0)
|
||||
else:
|
||||
# generate a wrong prediction if pred is not true/false
|
||||
outputs.append(0 if ans else 1)
|
||||
|
||||
assert(len(references)==len(outputs))
|
||||
formatted_predictions = []
|
||||
return EvalPrediction(predictions=outputs, label_ids=references)
|
||||
|
||||
def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
references = [exp['answers']['text'][0] for exp in examples]
|
||||
formatted_predictions = []
|
||||
preds, references = self.postprocess_text(preds,references)
|
||||
assert(len(preds)==len(references))
|
||||
return EvalPrediction(predictions=preds, label_ids=references)
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
if self.label_smoother is not None and "labels" in inputs:
|
||||
labels = inputs.pop("labels")
|
||||
inputs["train"]=True
|
||||
else:
|
||||
labels = None
|
||||
outputs,sim_loss = model(**inputs)
|
||||
|
||||
# print(model.encoder.)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if labels is not None:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
|
||||
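        # Add the auxiliary similarity loss returned by the model, down-weighted by a
        # fixed factor of 0.1 (hard-coded here rather than exposed as a config option).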
loss += sim_loss*0.1
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
# preds = [" ".join(pred) for pred in decoded_preds]
|
||||
preds = decoded_preds
|
||||
references = [exp['answers'][0]['text'] for exp in examples]
|
||||
formatted_predictions = []
|
||||
preds, references = self.postprocess_text(preds, references)
|
||||
assert (len(preds) == len(references))
|
||||
return EvalPrediction(predictions=preds, label_ids=references)
|
||||
|
||||
def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples]
|
||||
assert(len(references)==len(decoded_preds))
|
||||
gold_ids , pred_ids = [],[]
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
# reference = json.loads(reference)
|
||||
gold = int(ord(reference['answer'].strip()) - ord('A'))
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
# print('options',options,type(options))
|
||||
# print('prediction',prediction)
|
||||
# print('answer',gold)
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
# selected_ans = chr(ord('A') + max_idx)
|
||||
# print(len(references),len(decoded_preds))
|
||||
|
||||
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
|
||||
|
||||
def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples]
|
||||
assert len(references) == len(decoded_preds)
|
||||
gold_ids, pred_ids = [], []
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
# reference = json.loads(reference)
|
||||
gold = int(ord(reference['answer'].strip()) - ord('A'))
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
return EvalPrediction(predictions=pred_ids,label_ids = gold_ids)
|
||||
|
||||
def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"):
|
||||
preds = outputs.predictions
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples]
|
||||
assert(len(references)==len(decoded_preds))
|
||||
gold_ids, pred_ids = [],[]
|
||||
for prediction, reference in zip(decoded_preds, references):
|
||||
gold = reference['options'].index(reference['answer'])
|
||||
options = reference['options']
|
||||
prediction = prediction.replace("\n", "").strip()
|
||||
options = [opt.strip() for opt in options if len(opt) > 0]
|
||||
scores = [score_string_similarity(opt, prediction) for opt in options]
|
||||
max_idx = np.argmax(scores)
|
||||
gold_ids.append(gold)
|
||||
pred_ids.append(max_idx)
|
||||
return EvalPrediction(predictions=pred_ids, label_ids = gold_ids)
|
||||
|
||||
    def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None, ignore_keys=None, metric_key_prefix: str = "eval",
                 max_length: Optional[int] = None, num_beams: Optional[int] = None):
|
||||
gen_kwargs = {}
|
||||
        gen_kwargs["max_length"] = max_length
        gen_kwargs["num_beams"] = num_beams
|
||||
self._gen_kwargs = gen_kwargs
|
||||
self._memory_tracker.start()
|
||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
eval_examples = self.eval_examples if eval_examples is None else eval_examples
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
try:
|
||||
output = eval_loop(
|
||||
eval_dataloader,
|
||||
description="Evaluation",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
|
||||
if self.post_process_function is not None and self.compute_metrics is not None:
|
||||
eval_preds = self.post_process_function(eval_examples, eval_dataset, output,tokenizer)
|
||||
metrics = self.compute_metrics(eval_preds,
|
||||
tri=self.dataset_name=="multinli.in.out",
|
||||
dialogue=self.dataset_name=='woz.en',
|
||||
rouge=self.dataset_name=='cnn_dailymail',
|
||||
logical_form=self.dataset_name=='wikisql',
|
||||
corpus_f1=self.dataset_name=='zre')
|
||||
if self.dataset_name=='narrativeqa':
|
||||
metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
print(metrics)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
return metrics
|
||||
|
||||
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
|
||||
# want to save except FullyShardedDDP.
|
||||
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
|
||||
|
||||
# Save model checkpoint
|
||||
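        # save_only_best is hard-coded: the routine per-step checkpoint saves below are
        # skipped, and only the RNG state plus the best model (written to
        # 'best-checkpoint' further down when the metric improves) are persisted.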
save_only_best = True
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
|
||||
if self.hp_search_backend is not None and trial is not None:
|
||||
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||
run_id = trial.number
|
||||
elif self.hp_search_backend == HPSearchBackend.RAY:
|
||||
from ray import tune
|
||||
|
||||
run_id = tune.get_trial_id()
|
||||
elif self.hp_search_backend == HPSearchBackend.SIGOPT:
|
||||
run_id = trial.id
|
||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||
run_dir = os.path.join(self.args.output_dir, run_name)
|
||||
else:
|
||||
run_dir = self.args.output_dir
|
||||
self.store_flos()
|
||||
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
if not save_only_best:
|
||||
self.save_model(output_dir)
|
||||
if self.deepspeed:
|
||||
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
|
||||
# config `stage3_gather_fp16_weights_on_model_save` is True
|
||||
self.deepspeed.save_checkpoint(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
self.optimizer.consolidate_state_dict()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
if smp.dp_rank() == 0:
|
||||
# Consolidate the state dict on all processed of dp_rank 0
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
# Save it and the scheduler on the main process
|
||||
if self.args.should_save and not save_only_best:
|
||||
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.deepspeed and not save_only_best:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
self.do_grad_scaling=True
|
||||
# if self.do_grad_scaling:
|
||||
# torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||
metric_to_check = self.args.metric_for_best_model
|
||||
if not metric_to_check.startswith("eval_"):
|
||||
metric_to_check = f"eval_{metric_to_check}"
|
||||
metric_value = metrics[metric_to_check]
|
||||
|
||||
operator = np.greater if self.args.greater_is_better else np.less
|
||||
if (
|
||||
self.state.best_metric is None
|
||||
or self.state.best_model_checkpoint is None
|
||||
or operator(metric_value, self.state.best_metric)
|
||||
):
|
||||
self.state.best_metric = metric_value
|
||||
self.state.best_model_checkpoint = output_dir
|
||||
best_dir = os.path.join(run_dir,'best-checkpoint')
|
||||
# if not os.path.exists(best_dir):
|
||||
# os.mkdir(best_dir,exist_ok=True)
|
||||
if self.args.should_save:
|
||||
self.save_model(best_dir)
|
||||
self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME))
|
||||
|
||||
|
||||
# Save the Trainer state
|
||||
if self.args.should_save and not save_only_best:
|
||||
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
|
||||
|
||||
# Save RNG state in non-distributed training
|
||||
rng_states = {
|
||||
"python": random.getstate(),
|
||||
"numpy": np.random.get_state(),
|
||||
"cpu": torch.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
if self.args.local_rank == -1:
|
||||
# In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
|
||||
else:
|
||||
rng_states["cuda"] = torch.cuda.random.get_rng_state()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
rng_states["xla"] = xm.get_rng_state()
|
||||
|
||||
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
|
||||
# not yet exist.
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
|
||||
if local_rank == -1:
|
||||
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
|
||||
else:
|
||||
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
|
||||
|
||||
if self.args.push_to_hub:
|
||||
self._push_from_checkpoint(output_dir)
|
||||
|
||||
# Maybe delete some older checkpoints.
|
||||
if self.args.should_save and not save_only_best:
|
||||
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
|
||||
|
||||
|
||||
|
||||
|
||||
# def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict",
|
||||
# max_length: Optional[int] = None,num_beams: Optional[int] = None):
|
||||
# self._memory_tracker.start()
|
||||
# self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
# self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
# predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset
|
||||
# predict_dataloader = self.get_eval_dataloader(predict_dataset)
|
||||
# predict_examples = predict_examples
|
||||
#
|
||||
# # Temporarily disable metric computation, we will do it in the loop here.
|
||||
# compute_metrics = self.compute_metrics
|
||||
# self.compute_metrics = None
|
||||
# eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
# try:
|
||||
# output = eval_loop(
|
||||
# predict_dataloader,
|
||||
# description="Predict",
|
||||
# # No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# # self.args.prediction_loss_only
|
||||
# prediction_loss_only=True if compute_metrics is None else None,
|
||||
# ignore_keys=ignore_keys,
|
||||
# )
|
||||
# finally:
|
||||
# self.compute_metrics = compute_metrics
|
||||
#
|
||||
# if self.post_process_function is not None and self.compute_metrics is not None:
|
||||
# predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer)
|
||||
# metrics = self.compute_metrics(predict_preds)
|
||||
# if self.dataset_name=='narrativeqa':
|
||||
# metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()}
|
||||
# metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
# # Prefix all keys with metric_key_prefix + '_'
|
||||
# for key in list(metrics.keys()):
|
||||
# if not key.startswith(f"{metric_key_prefix}_"):
|
||||
# metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
#
|
||||
# self.log(metrics)
|
||||
# else:
|
||||
# metrics = {}
|
||||
#
|
||||
# self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
# self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
# return metrics
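# Illustrative usage sketch (not from the original commit). `model`, `training_args`,
# `tokenizer`, `tokenized_train`, `tokenized_eval`, `eval_examples` and `compute_metrics`
# are placeholders assumed to be built elsewhere; note that `compute_metrics` must accept
# the extra keyword flags (tri, dialogue, rouge, logical_form, corpus_f1) passed in
# evaluate() above.
#
# trainer = QuestionAnsweringTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_train,
#     eval_dataset=tokenized_eval,
#     eval_examples=eval_examples,
#     tokenizer=tokenizer,
#     dataset_name="squad",
#     compute_metrics=compute_metrics,
# )
# trainer.train()
# metrics = trainer.evaluate(max_length=64, num_beams=4, tokenizer=tokenizer)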
@@ -0,0 +1,105 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Accuracy metric."""
|
||||
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = """
|
||||
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
|
||||
Accuracy = (TP + TN) / (TP + TN + FP + FN)
|
||||
Where:
|
||||
TP: True positive
|
||||
TN: True negative
|
||||
FP: False positive
|
||||
FN: False negative
|
||||
"""
|
||||
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Args:
|
||||
predictions (`list` of `int`): Predicted labels.
|
||||
references (`list` of `int`): Ground truth labels.
|
||||
normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True.
|
||||
    sample_weight (`list` of `float`): Sample weights. Defaults to None.
|
||||
|
||||
Returns:
|
||||
    accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of correctly classified examples if `normalize` is set to `False`. A higher score means higher accuracy.
|
||||
|
||||
Examples:
|
||||
|
||||
Example 1-A simple example
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0])
|
||||
>>> print(results)
|
||||
{'accuracy': 0.5}
|
||||
|
||||
Example 2-The same as Example 1, except with `normalize` set to `False`.
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False)
|
||||
>>> print(results)
|
||||
{'accuracy': 3.0}
|
||||
|
||||
Example 3-The same as Example 1, except with `sample_weight` set.
|
||||
>>> accuracy_metric = datasets.load_metric("accuracy")
|
||||
>>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4])
|
||||
>>> print(results)
|
||||
{'accuracy': 0.8778625954198473}
|
||||
"""
|
||||
|
||||
|
||||
_CITATION = """
|
||||
@article{scikit-learn,
|
||||
title={Scikit-learn: Machine Learning in {P}ython},
|
||||
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
|
||||
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
|
||||
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
|
||||
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
|
||||
journal={Journal of Machine Learning Research},
|
||||
volume={12},
|
||||
pages={2825--2830},
|
||||
year={2011}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Accuracy(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("int32")),
|
||||
"references": datasets.Sequence(datasets.Value("int32")),
|
||||
}
|
||||
if self.config_name == "multilabel"
|
||||
else {
|
||||
"predictions": datasets.Value("int32"),
|
||||
"references": datasets.Value("int32"),
|
||||
}
|
||||
),
|
||||
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, normalize=True, sample_weight=None):
|
||||
return {
|
||||
"accuracy": float(
|
||||
accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
|
||||
)
|
||||
}
@@ -0,0 +1,126 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BLEU metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@INPROCEEDINGS{Papineni02bleu:a,
|
||||
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
|
||||
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
|
||||
booktitle = {},
|
||||
year = {2002},
|
||||
pages = {311--318}
|
||||
}
|
||||
@inproceedings{lin-och-2004-orange,
|
||||
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
|
||||
author = "Lin, Chin-Yew and
|
||||
Och, Franz Josef",
|
||||
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
|
||||
month = "aug 23{--}aug 27",
|
||||
year = "2004",
|
||||
address = "Geneva, Switzerland",
|
||||
publisher = "COLING",
|
||||
url = "https://www.aclweb.org/anthology/C04-1072",
|
||||
pages = "501--507",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
|
||||
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
|
||||
the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
|
||||
remains one of the most popular automated and inexpensive metrics.
|
||||
|
||||
Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations.
|
||||
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
|
||||
are not taken into account.
|
||||
|
||||
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
|
||||
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
|
||||
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
|
||||
reference translations will increase the BLEU score.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
predictions: list of translations to score.
|
||||
Each translation should be tokenized into a list of tokens.
|
||||
references: list of lists of references for each translation.
|
||||
Each reference should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
Returns:
|
||||
'bleu': bleu score,
|
||||
'precisions': geometric mean of n-gram precisions,
|
||||
'brevity_penalty': brevity penalty,
|
||||
'length_ratio': ratio of lengths,
|
||||
'translation_length': translation_length,
|
||||
'reference_length': reference_length
|
||||
Examples:
|
||||
|
||||
>>> predictions = [
|
||||
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
|
||||
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
|
||||
... ]
|
||||
>>> references = [
|
||||
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
|
||||
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
|
||||
... ]
|
||||
>>> bleu = datasets.load_metric("bleu")
|
||||
>>> results = bleu.compute(predictions=predictions, references=references)
|
||||
>>> print(results["bleu"])
|
||||
1.0
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Bleu(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
|
||||
"references": datasets.Sequence(
|
||||
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
|
||||
),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/BLEU",
|
||||
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
|
||||
],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, max_order=4, smooth=False):
|
||||
score = compute_bleu(
|
||||
reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth
|
||||
)
|
||||
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
|
||||
return {
|
||||
"bleu": bleu,
|
||||
"precisions": precisions,
|
||||
"brevity_penalty": bp,
|
||||
"length_ratio": ratio,
|
||||
"translation_length": translation_length,
|
||||
"reference_length": reference_length,
|
||||
}
@@ -0,0 +1,219 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BLEU metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@INPROCEEDINGS{Papineni02bleu:a,
|
||||
author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu},
|
||||
title = {BLEU: a Method for Automatic Evaluation of Machine Translation},
|
||||
booktitle = {},
|
||||
year = {2002},
|
||||
pages = {311--318}
|
||||
}
|
||||
@inproceedings{lin-och-2004-orange,
|
||||
title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation",
|
||||
author = "Lin, Chin-Yew and
|
||||
Och, Franz Josef",
|
||||
booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics",
|
||||
month = "aug 23{--}aug 27",
|
||||
year = "2004",
|
||||
address = "Geneva, Switzerland",
|
||||
publisher = "COLING",
|
||||
url = "https://www.aclweb.org/anthology/C04-1072",
|
||||
pages = "501--507",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another.
|
||||
Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation,
|
||||
the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and
|
||||
remains one of the most popular automated and inexpensive metrics.
|
||||
|
||||
Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations.
|
||||
Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness
|
||||
are not taken into account.
|
||||
|
||||
BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1
|
||||
representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the
|
||||
reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional
|
||||
reference translations will increase the BLEU score.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
predictions: list of translations to score.
|
||||
Each translation should be tokenized into a list of tokens.
|
||||
references: list of lists of references for each translation.
|
||||
Each reference should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
Returns:
|
||||
'bleu': bleu score,
|
||||
'precisions': geometric mean of n-gram precisions,
|
||||
'brevity_penalty': brevity penalty,
|
||||
'length_ratio': ratio of lengths,
|
||||
'translation_length': translation_length,
|
||||
'reference_length': reference_length
|
||||
Examples:
|
||||
|
||||
>>> predictions = [
|
||||
... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample
|
||||
... ["foo", "bar", "foobar"] # tokenized prediction of the second sample
|
||||
... ]
|
||||
>>> references = [
|
||||
... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references)
|
||||
... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference)
|
||||
... ]
|
||||
>>> bleu = datasets.load_metric("bleu")
|
||||
>>> results = bleu.compute(predictions=predictions, references=references)
|
||||
>>> print(results["bleu"])
|
||||
1.0
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
|
||||
def _get_ngrams(segment, max_order):
  """Extracts all n-grams up to a given maximum order from an input segment.
  Args:
    segment: text segment from which n-grams will be extracted.
    max_order: maximum length in tokens of the n-grams returned by this
      method.
  Returns:
    The Counter containing all n-grams up to max_order in segment
    with a count of how many times each n-gram occurred.
  """
|
||||
ngram_counts = collections.Counter()
|
||||
for order in range(1, max_order + 1):
|
||||
for i in range(0, len(segment) - order + 1):
|
||||
ngram = tuple(segment[i:i+order])
|
||||
ngram_counts[ngram] += 1
|
||||
return ngram_counts
|
||||
|
||||
|
||||
def compute_bleu(reference_corpus, translation_corpus, max_order=1,
|
||||
smooth=False):
|
||||
"""Computes BLEU score of translated segments against one or more references.
|
||||
Args:
|
||||
reference_corpus: list of lists of references for each translation. Each
|
||||
reference should be tokenized into a list of tokens.
|
||||
translation_corpus: list of translations to score. Each translation
|
||||
should be tokenized into a list of tokens.
|
||||
max_order: Maximum n-gram order to use when computing BLEU score.
|
||||
smooth: Whether or not to apply Lin et al. 2004 smoothing.
|
||||
Returns:
|
||||
3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
|
||||
precisions and brevity penalty.
|
||||
"""
|
||||
matches_by_order = [0] * max_order
|
||||
possible_matches_by_order = [0] * max_order
|
||||
reference_length = 0
|
||||
translation_length = 0
|
||||
for (references, translation) in zip(reference_corpus,
|
||||
translation_corpus):
|
||||
# translation = references[0]
|
||||
translation = [item.lower() for item in translation]
|
||||
|
||||
reference_length += min(len(r) for r in references)
|
||||
translation_length += len(translation)
|
||||
|
||||
merged_ref_ngram_counts = collections.Counter()
|
||||
|
||||
for reference in references:
|
||||
merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
|
||||
# print(merged_ref_ngram_counts)
|
||||
translation_ngram_counts = _get_ngrams(translation, max_order)
|
||||
overlap = translation_ngram_counts & merged_ref_ngram_counts
|
||||
for ngram in overlap:
|
||||
matches_by_order[len(ngram)-1] += overlap[ngram]
|
||||
for order in range(1, max_order+1):
|
||||
possible_matches = len(translation) - order + 1
|
||||
if possible_matches > 0:
|
||||
possible_matches_by_order[order-1] += possible_matches
|
||||
|
||||
|
||||
precisions = [0] * max_order
|
||||
for i in range(0, max_order):
|
||||
if smooth:
|
||||
precisions[i] = ((matches_by_order[i] + 1.) /
|
||||
(possible_matches_by_order[i] + 1.))
|
||||
else:
|
||||
if possible_matches_by_order[i] > 0:
|
||||
precisions[i] = (float(matches_by_order[i]) /
|
||||
possible_matches_by_order[i])
|
||||
else:
|
||||
precisions[i] = 0.0
|
||||
|
||||
|
||||
if min(precisions) > 0:
|
||||
p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions)
|
||||
geo_mean = math.exp(p_log_sum)
|
||||
else:
|
||||
geo_mean = 0
|
||||
|
||||
ratio = float(translation_length) / reference_length
|
||||
|
||||
if ratio > 1.0:
|
||||
bp = 1.
|
||||
else:
|
||||
bp = math.exp(1 - 1. / ratio)
|
||||
|
||||
bleu = geo_mean * bp
|
||||
|
||||
|
||||
return (bleu, precisions, bp, ratio, translation_length, reference_length)
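# Worked sketch (not part of the original file): with this modified copy's defaults
# (max_order=1 and lower-cased translations), BLEU reduces to unigram precision
# times the brevity penalty.
#
# >>> refs = [[["the", "cat", "sat"]]]      # one sample with a single tokenized reference
# >>> hyp = [["The", "cat", "slept"]]       # tokenized prediction for that sample
# >>> bleu, precisions, bp, *_ = compute_bleu(refs, hyp)
# >>> round(precisions[0], 3)               # 2 of 3 unigrams match after lower-casing
# 0.667
# >>> bp                                    # equal lengths, so no brevity penalty
# 1.0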
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Bleu(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"),
|
||||
"references": datasets.Sequence(
|
||||
datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references"
|
||||
),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/BLEU",
|
||||
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
|
||||
],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, max_order=4, smooth=False):
|
||||
score = compute_bleu(
|
||||
reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth
|
||||
)
|
||||
(bleu, precisions, bp, ratio, translation_length, reference_length) = score
|
||||
return {
|
||||
"bleu": bleu,
|
||||
"precisions": precisions,
|
||||
"brevity_penalty": bp,
|
||||
"length_ratio": ratio,
|
||||
"translation_length": translation_length,
|
||||
"reference_length": reference_length,
|
||||
}
@@ -0,0 +1,133 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" ROUGE metric from Google Research github repo. """
|
||||
|
||||
# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
|
||||
import absl # Here to have a nice missing dependency error message early on
|
||||
import nltk # Here to have a nice missing dependency error message early on
|
||||
import numpy # Here to have a nice missing dependency error message early on
|
||||
import six # Here to have a nice missing dependency error message early on
|
||||
from .rouge_score import RougeScorer
|
||||
from .scoring import *
|
||||
from .rouge_tokenizers import *
|
||||
import datasets
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@inproceedings{lin-2004-rouge,
|
||||
title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
|
||||
author = "Lin, Chin-Yew",
|
||||
booktitle = "Text Summarization Branches Out",
|
||||
month = jul,
|
||||
year = "2004",
|
||||
address = "Barcelona, Spain",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
url = "https://www.aclweb.org/anthology/W04-1013",
|
||||
pages = "74--81",
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
|
||||
evaluating automatic summarization and machine translation software in natural language processing.
|
||||
The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation.
|
||||
|
||||
Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.
|
||||
|
||||
This metric is a wrapper around the Google Research reimplementation of ROUGE:
|
||||
https://github.com/google-research/google-research/tree/master/rouge
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Calculates average rouge scores for a list of hypotheses and references
|
||||
Args:
|
||||
predictions: list of predictions to score. Each prediction
|
||||
should be a string with tokens separated by spaces.
|
||||
references: list of reference for each prediction. Each
|
||||
reference should be a string with tokens separated by spaces.
|
||||
rouge_types: A list of rouge types to calculate.
|
||||
Valid names:
|
||||
`"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
|
||||
`"rougeL"`: Longest common subsequence based scoring.
|
||||
`"rougeLSum"`: rougeLsum splits text using `"\n"`.
|
||||
See details in https://github.com/huggingface/datasets/issues/617
|
||||
use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
|
||||
use_aggregator: Return aggregates if this is set to True
|
||||
Returns:
|
||||
rouge1: rouge_1 (precision, recall, f1),
|
||||
rouge2: rouge_2 (precision, recall, f1),
|
||||
rougeL: rouge_l (precision, recall, f1),
|
||||
rougeLsum: rouge_lsum (precision, recall, f1)
|
||||
Examples:
|
||||
|
||||
>>> rouge = datasets.load_metric('rouge')
|
||||
>>> predictions = ["hello there", "general kenobi"]
|
||||
>>> references = ["hello there", "general kenobi"]
|
||||
>>> results = rouge.compute(predictions=predictions, references=references)
|
||||
>>> print(list(results.keys()))
|
||||
['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
|
||||
>>> print(results["rouge1"])
|
||||
AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
|
||||
>>> print(results["rouge1"].mid.fmeasure)
|
||||
1.0
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Rouge(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": datasets.Value("string", id="sequence"),
|
||||
"references": datasets.Value("string", id="sequence"),
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
|
||||
reference_urls=[
|
||||
"https://en.wikipedia.org/wiki/ROUGE_(metric)",
|
||||
"https://github.com/google-research/google-research/tree/master/rouge",
|
||||
],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
|
||||
if rouge_types is None:
|
||||
rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
||||
|
||||
scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
|
||||
if use_aggregator:
|
||||
aggregator = BootstrapAggregator()
|
||||
else:
|
||||
scores = []
|
||||
|
||||
for ref, pred in zip(references, predictions):
|
||||
score = scorer.score(ref, pred)
|
||||
if use_aggregator:
|
||||
aggregator.add_scores(score)
|
||||
else:
|
||||
scores.append(score)
|
||||
|
||||
if use_aggregator:
|
||||
result = aggregator.aggregate()
|
||||
for key in rouge_types:
|
||||
result[key] = result[key].mid.fmeasure
|
||||
else:
|
||||
result = {}
|
||||
for key in scores[0]:
|
||||
result[key] = list(score[key] for score in scores)
|
||||
|
||||
return result
@@ -0,0 +1,318 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Computes rouge scores between two text blobs.
|
||||
Implementation replicates the functionality in the original ROUGE package. See:
|
||||
Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In
|
||||
Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004),
|
||||
Barcelona, Spain, July 25 - 26, 2004.
|
||||
Default options are equivalent to running:
|
||||
ROUGE-1.5.5.pl -e data -n 2 -a settings.xml
|
||||
Or with use_stemmer=True:
|
||||
ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
|
||||
In these examples settings.xml lists input files and formats.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
|
||||
from absl import logging
|
||||
import nltk
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import map
|
||||
from six.moves import range
|
||||
from .scoring import *
|
||||
from .rouge_tokenizers import *
|
||||
|
||||
|
||||
|
||||
class RougeScorer(BaseScorer):
  """Calculate rouge scores between two blobs of text.
|
||||
Sample usage:
|
||||
scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
|
||||
scores = scorer.score('The quick brown fox jumps over the lazy dog',
|
||||
'The quick brown dog jumps on the log.')
|
||||
"""
|
||||
|
||||
def __init__(self, rouge_types, use_stemmer=False, split_summaries=False,
|
||||
tokenizer=None):
|
||||
"""Initializes a new RougeScorer.
|
||||
Valid rouge types that can be computed are:
|
||||
rougen (e.g. rouge1, rouge2): n-gram based scoring.
|
||||
rougeL: Longest common subsequence based scoring.
|
||||
Args:
|
||||
rouge_types: A list of rouge types to calculate.
|
||||
use_stemmer: Bool indicating whether Porter stemmer should be used to
|
||||
strip word suffixes to improve matching. This arg is used in the
|
||||
DefaultTokenizer, but other tokenizers might or might not choose to
|
||||
use this.
|
||||
split_summaries: whether to add newlines between sentences for rougeLsum
|
||||
tokenizer: Tokenizer object which has a tokenize() method.
|
||||
Returns:
|
||||
A dict mapping rouge types to Score tuples.
|
||||
"""
|
||||
|
||||
self.rouge_types = rouge_types
|
||||
if tokenizer:
|
||||
self._tokenizer = tokenizer
|
||||
else:
|
||||
self._tokenizer = DefaultTokenizer(use_stemmer)
|
||||
logging.info("Using default tokenizer.")
|
||||
|
||||
self._split_summaries = split_summaries
|
||||
|
||||
def score_multi(self, targets, prediction):
|
||||
"""Calculates rouge scores between targets and prediction.
|
||||
The target with the maximum f-measure is used for the final score for
|
||||
    each score type.
|
||||
Args:
|
||||
targets: list of texts containing the targets
|
||||
prediction: Text containing the predicted text.
|
||||
Returns:
|
||||
A dict mapping each rouge type to a Score object.
|
||||
Raises:
|
||||
ValueError: If an invalid rouge type is encountered.
|
||||
"""
|
||||
score_dicts = [self.score(t, prediction) for t in targets]
|
||||
max_score = {}
|
||||
for k in self.rouge_types:
|
||||
index = np.argmax([s[k].fmeasure for s in score_dicts])
|
||||
max_score[k] = score_dicts[index][k]
|
||||
|
||||
return max_score
|
||||
|
||||
def score(self, target, prediction):
|
||||
"""Calculates rouge scores between the target and prediction.
|
||||
Args:
|
||||
      target: Text containing the target (ground truth) text.
|
||||
prediction: Text containing the predicted text.
|
||||
Returns:
|
||||
A dict mapping each rouge type to a Score object.
|
||||
Raises:
|
||||
ValueError: If an invalid rouge type is encountered.
|
||||
"""
|
||||
|
||||
# Pre-compute target tokens and prediction tokens for use by different
|
||||
# types, except if only "rougeLsum" is requested.
|
||||
if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum":
|
||||
target_tokens = None
|
||||
prediction_tokens = None
|
||||
else:
|
||||
target_tokens = self._tokenizer.tokenize(target)
|
||||
prediction_tokens = self._tokenizer.tokenize(prediction)
|
||||
result = {}
|
||||
|
||||
for rouge_type in self.rouge_types:
|
||||
if rouge_type == "rougeL":
|
||||
# Rouge from longest common subsequences.
|
||||
scores = _score_lcs(target_tokens, prediction_tokens)
|
||||
elif rouge_type == "rougeLsum":
|
||||
# Note: Does not support multi-line text.
|
||||
def get_sents(text):
|
||||
if self._split_summaries:
|
||||
sents = nltk.sent_tokenize(text)
|
||||
else:
|
||||
# Assume sentences are separated by newline.
|
||||
sents = six.ensure_str(text).split("\n")
|
||||
sents = [x for x in sents if len(x)]
|
||||
return sents
|
||||
|
||||
target_tokens_list = [
|
||||
self._tokenizer.tokenize(s) for s in get_sents(target)]
|
||||
prediction_tokens_list = [
|
||||
self._tokenizer.tokenize(s) for s in get_sents(prediction)]
|
||||
|
||||
scores = _summary_level_lcs(target_tokens_list,
|
||||
prediction_tokens_list)
|
||||
elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)):
|
||||
# Rouge from n-grams.
|
||||
n = int(rouge_type[5:])
|
||||
if n <= 0:
|
||||
raise ValueError("rougen requires positive n: %s" % rouge_type)
|
||||
target_ngrams = _create_ngrams(target_tokens, n)
|
||||
prediction_ngrams = _create_ngrams(prediction_tokens, n)
|
||||
scores = _score_ngrams(target_ngrams, prediction_ngrams)
|
||||
else:
|
||||
raise ValueError("Invalid rouge type: %s" % rouge_type)
|
||||
result[rouge_type] = scores
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _create_ngrams(tokens, n):
|
||||
"""Creates ngrams from the given list of tokens.
|
||||
Args:
|
||||
tokens: A list of tokens from which ngrams are created.
|
||||
n: Number of tokens to use, e.g. 2 for bigrams.
|
||||
Returns:
|
||||
A dictionary mapping each bigram to the number of occurrences.
|
||||
"""
|
||||
|
||||
ngrams = collections.Counter()
|
||||
for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)):
|
||||
ngrams[ngram] += 1
|
||||
return ngrams
|
||||
|
||||
|
||||
def _score_lcs(target_tokens, prediction_tokens):
|
||||
"""Computes LCS (Longest Common Subsequence) rouge scores.
|
||||
Args:
|
||||
target_tokens: Tokens from the target text.
|
||||
prediction_tokens: Tokens from the predicted text.
|
||||
Returns:
|
||||
A Score object containing computed scores.
|
||||
"""
|
||||
|
||||
if not target_tokens or not prediction_tokens:
|
||||
return Score(precision=0, recall=0, fmeasure=0)
|
||||
|
||||
  # Compute length of LCS from the bottom up in a table (DP approach).
|
||||
lcs_table = _lcs_table(target_tokens, prediction_tokens)
|
||||
lcs_length = lcs_table[-1][-1]
|
||||
|
||||
precision = lcs_length / len(prediction_tokens)
|
||||
recall = lcs_length / len(target_tokens)
|
||||
fmeasure = func_fmeasure(precision, recall)
|
||||
|
||||
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
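# Worked sketch (illustrative, not part of the original file): for target
# ["the", "cat", "sat"] and prediction ["the", "cat", "slept"] the LCS is
# ["the", "cat"], so precision = recall = 2/3 and the F-measure is also 2/3.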
|
||||
|
||||
|
||||
def _lcs_table(ref, can):
|
||||
"""Create 2-d LCS score table."""
|
||||
rows = len(ref)
|
||||
cols = len(can)
|
||||
lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
|
||||
for i in range(1, rows + 1):
|
||||
for j in range(1, cols + 1):
|
||||
if ref[i - 1] == can[j - 1]:
|
||||
lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
|
||||
else:
|
||||
lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
|
||||
return lcs_table
|
||||
|
||||
|
||||
def _backtrack_norec(t, ref, can):
|
||||
"""Read out LCS."""
|
||||
i = len(ref)
|
||||
j = len(can)
|
||||
lcs = []
|
||||
while i > 0 and j > 0:
|
||||
if ref[i - 1] == can[j - 1]:
|
||||
lcs.insert(0, i-1)
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif t[i][j - 1] > t[i - 1][j]:
|
||||
j -= 1
|
||||
else:
|
||||
i -= 1
|
||||
return lcs
|
||||
|
||||
|
||||
def _summary_level_lcs(ref_sent, can_sent):
|
||||
"""ROUGE: Summary-level LCS, section 3.2 in ROUGE paper.
|
||||
Args:
|
||||
ref_sent: list of tokenized reference sentences
|
||||
can_sent: list of tokenized candidate sentences
|
||||
Returns:
|
||||
summary level ROUGE score
|
||||
"""
|
||||
if not ref_sent or not can_sent:
|
||||
return Score(precision=0, recall=0, fmeasure=0)
|
||||
|
||||
m = sum(map(len, ref_sent))
|
||||
n = sum(map(len, can_sent))
|
||||
if not n or not m:
|
||||
return Score(precision=0, recall=0, fmeasure=0)
|
||||
|
||||
# get token counts to prevent double counting
|
||||
token_cnts_r = collections.Counter()
|
||||
token_cnts_c = collections.Counter()
|
||||
for s in ref_sent:
|
||||
# s is a list of tokens
|
||||
token_cnts_r.update(s)
|
||||
for s in can_sent:
|
||||
token_cnts_c.update(s)
|
||||
|
||||
hits = 0
|
||||
for r in ref_sent:
|
||||
lcs = _union_lcs(r, can_sent)
|
||||
# Prevent double-counting:
|
||||
# The paper describes just computing hits += len(_union_lcs()),
|
||||
# but the implementation prevents double counting. We also
|
||||
# implement this as in version 1.5.5.
|
||||
for t in lcs:
|
||||
if token_cnts_c[t] > 0 and token_cnts_r[t] > 0:
|
||||
hits += 1
|
||||
token_cnts_c[t] -= 1
|
||||
token_cnts_r[t] -= 1
|
||||
|
||||
recall = hits / m
|
||||
precision = hits / n
|
||||
fmeasure = func_fmeasure(precision, recall)
|
||||
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
|
||||
|
||||
|
||||
def _union_lcs(ref, c_list):
|
||||
"""Find union LCS between a ref sentence and list of candidate sentences.
|
||||
Args:
|
||||
ref: list of tokens
|
||||
c_list: list of list of indices for LCS into reference summary
|
||||
Returns:
|
||||
List of tokens in ref representing union LCS.
|
||||
"""
|
||||
lcs_list = [lcs_ind(ref, c) for c in c_list]
|
||||
return [ref[i] for i in _find_union(lcs_list)]
|
||||
|
||||
|
||||
def _find_union(lcs_list):
|
||||
"""Finds union LCS given a list of LCS."""
|
||||
return sorted(list(set().union(*lcs_list)))
|
||||
|
||||
|
||||
def lcs_ind(ref, can):
|
||||
"""Returns one of the longest lcs."""
|
||||
t = _lcs_table(ref, can)
|
||||
return _backtrack_norec(t, ref, can)
|
||||
|
||||
|
||||
def _score_ngrams(target_ngrams, prediction_ngrams):
|
||||
"""Compute n-gram based rouge scores.
|
||||
Args:
|
||||
target_ngrams: A Counter object mapping each ngram to number of
|
||||
occurrences for the target text.
|
||||
prediction_ngrams: A Counter object mapping each ngram to number of
|
||||
occurrences for the prediction text.
|
||||
Returns:
|
||||
A Score object containing computed scores.
|
||||
"""
|
||||
|
||||
intersection_ngrams_count = 0
|
||||
for ngram in six.iterkeys(target_ngrams):
|
||||
intersection_ngrams_count += min(target_ngrams[ngram],
|
||||
prediction_ngrams[ngram])
|
||||
target_ngrams_count = sum(target_ngrams.values())
|
||||
prediction_ngrams_count = sum(prediction_ngrams.values())
|
||||
|
||||
precision = intersection_ngrams_count / max(prediction_ngrams_count, 1)
|
||||
recall = intersection_ngrams_count / max(target_ngrams_count, 1)
|
||||
fmeasure = func_fmeasure(precision, recall)
|
||||
|
||||
return Score(precision=precision, recall=recall, fmeasure=fmeasure)
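# Illustrative sketch, not part of the original file: a tiny ROUGE-1 check using the
# helpers above. It assumes this module's _create_ngrams/_score_ngrams as defined here
# and only runs when the module is executed directly.
if __name__ == "__main__":
  _target = ["the", "cat", "sat", "on", "the", "mat"]
  _prediction = ["the", "cat", "on", "a", "mat"]
  # 4 overlapping unigrams -> precision 4/5, recall 4/6, f-measure their harmonic mean.
  print(_score_ngrams(_create_ngrams(_target, 1), _create_ngrams(_prediction, 1)))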
@@ -0,0 +1,90 @@
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
"""Library containing Tokenizer definitions.
|
||||
The RougeScorer class can be instantiated with the tokenizers defined here. New
|
||||
tokenizers can be defined by creating a subclass of the Tokenizer abstract class
|
||||
and overriding the tokenize() method.
|
||||
"""
|
||||
import abc
|
||||
from nltk.stem import porter
|
||||
|
||||
|
||||
import re
|
||||
import six
|
||||
|
||||
|
||||
# Pre-compile regexes that are used often
|
||||
NON_ALPHANUM_PATTERN = r"[^a-z0-9]+"
|
||||
NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN)
|
||||
SPACES_PATTERN = r"\s+"
|
||||
SPACES_RE = re.compile(SPACES_PATTERN)
|
||||
VALID_TOKEN_PATTERN = r"^[a-z0-9]+$"
|
||||
VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN)
|
||||
|
||||
|
||||
def _tokenize(text, stemmer):
|
||||
"""Tokenize input text into a list of tokens.
|
||||
This approach aims to replicate the approach taken by Chin-Yew Lin in
|
||||
the original ROUGE implementation.
|
||||
Args:
|
||||
text: A text blob to tokenize.
|
||||
stemmer: An optional stemmer.
|
||||
Returns:
|
||||
A list of string tokens extracted from input text.
|
||||
"""
|
||||
|
||||
# Convert everything to lowercase.
|
||||
text = text.lower()
|
||||
# Replace any non-alpha-numeric characters with spaces.
|
||||
text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text))
|
||||
|
||||
tokens = SPACES_RE.split(text)
|
||||
if stemmer:
|
||||
# Only stem words more than 3 characters long.
|
||||
tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x
|
||||
for x in tokens]
|
||||
|
||||
# One final check to drop any empty or invalid tokens.
|
||||
tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
class Tokenizer(abc.ABC):
|
||||
"""Abstract base class for a tokenizer.
|
||||
Subclasses of Tokenizer must implement the tokenize() method.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def tokenize(self, text):
|
||||
raise NotImplementedError("Tokenizer must override tokenize() method")
|
||||
|
||||
|
||||
class DefaultTokenizer(Tokenizer):
|
||||
"""Default tokenizer which tokenizes on whitespace."""
|
||||
|
||||
def __init__(self, use_stemmer=False):
|
||||
"""Constructor for DefaultTokenizer.
|
||||
Args:
|
||||
use_stemmer: boolean, indicating whether Porter stemmer should be used to
|
||||
strip word suffixes to improve matching.
|
||||
"""
|
||||
self._stemmer = porter.PorterStemmer() if use_stemmer else None
|
||||
|
||||
def tokenize(self, text):
|
||||
return _tokenize(text, self._stemmer)
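# Illustrative sketch, not part of the original file: DefaultTokenizer lowercases,
# drops non-alphanumeric characters and optionally Porter-stems tokens longer than
# three characters. It only runs when this module is executed directly.
if __name__ == "__main__":
  print(DefaultTokenizer().tokenize("The cats, running fast!"))
  # -> ['the', 'cats', 'running', 'fast']
  print(DefaultTokenizer(use_stemmer=True).tokenize("The cats, running fast!"))
  # -> ['the', 'cat', 'run', 'fast'] (stemmed)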
@@ -0,0 +1,158 @@
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Library for scoring and evaluation of text samples.
|
||||
Aggregation functions use bootstrap resampling to compute confidence intervals
|
||||
as per the original ROUGE perl implementation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import range
|
||||
|
||||
|
||||
class Score(
|
||||
collections.namedtuple("Score", ["precision", "recall", "fmeasure"])):
|
||||
"""Tuple containing precision, recall, and f-measure values."""
|
||||
|
||||
|
||||
class BaseScorer(object, metaclass=abc.ABCMeta):
|
||||
"""Base class for Scorer objects."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def score(self, target, prediction):
|
||||
"""Calculates score between the target and prediction.
|
||||
Args:
|
||||
target: Text containing the target (ground truth) text.
|
||||
prediction: Text containing the predicted text.
|
||||
Returns:
|
||||
A dict mapping each score_type (string) to Score object.
|
||||
"""
|
||||
|
||||
|
||||
class AggregateScore(
|
||||
collections.namedtuple("AggregateScore", ["low", "mid", "high"])):
|
||||
"""Tuple containing confidence intervals for scores."""
|
||||
|
||||
|
||||
class BootstrapAggregator(object):
|
||||
"""Aggregates scores to provide confidence intervals.
|
||||
Sample usage:
|
||||
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
|
||||
    aggregator = BootstrapAggregator()
|
||||
aggregator.add_scores(scorer.score("one two three", "one two"))
|
||||
aggregator.add_scores(scorer.score("one two five six", "seven eight"))
|
||||
result = aggregator.aggregate()
|
||||
    print(result)
|
||||
{'rougeL': AggregateScore(
|
||||
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
|
||||
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
|
||||
high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
|
||||
'rouge1': AggregateScore(
|
||||
low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
|
||||
mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
|
||||
high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
|
||||
"""
|
||||
|
||||
def __init__(self, confidence_interval=0.95, n_samples=1000):
|
||||
"""Initializes a BootstrapAggregator object.
|
||||
Args:
|
||||
confidence_interval: Confidence interval to compute on the mean as a
|
||||
decimal.
|
||||
n_samples: Number of samples to use for bootstrap resampling.
|
||||
Raises:
|
||||
ValueError: If invalid argument is given.
|
||||
"""
|
||||
|
||||
if confidence_interval < 0 or confidence_interval > 1:
|
||||
raise ValueError("confidence_interval must be in range [0, 1]")
|
||||
if n_samples <= 0:
|
||||
raise ValueError("n_samples must be positive")
|
||||
|
||||
self._n_samples = n_samples
|
||||
self._confidence_interval = confidence_interval
|
||||
self._scores = collections.defaultdict(list)
|
||||
|
||||
def add_scores(self, scores):
|
||||
"""Adds a sample for future aggregation.
|
||||
Args:
|
||||
scores: Dict mapping score_type strings to a namedtuple object/class
|
||||
representing a score.
|
||||
"""
|
||||
|
||||
for score_type, score in six.iteritems(scores):
|
||||
self._scores[score_type].append(score)
|
||||
|
||||
def aggregate(self):
|
||||
"""Aggregates scores previously added using add_scores.
|
||||
Returns:
|
||||
A dict mapping score_type to AggregateScore objects.
|
||||
"""
|
||||
|
||||
result = {}
|
||||
for score_type, scores in six.iteritems(self._scores):
|
||||
# Stack scores into a 2-d matrix of (sample, measure).
|
||||
score_matrix = np.vstack(tuple(scores))
|
||||
# Percentiles are returned as (interval, measure).
|
||||
percentiles = self._bootstrap_resample(score_matrix)
|
||||
# Extract the three intervals (low, mid, high).
|
||||
intervals = tuple(
|
||||
(scores[0].__class__(*percentiles[j, :]) for j in range(3)))
|
||||
result[score_type] = AggregateScore(
|
||||
low=intervals[0], mid=intervals[1], high=intervals[2])
|
||||
return result
|
||||
|
||||
def _bootstrap_resample(self, matrix):
|
||||
"""Performs bootstrap resampling on a matrix of scores.
|
||||
Args:
|
||||
matrix: A 2-d matrix of (sample, measure).
|
||||
Returns:
|
||||
A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
|
||||
      mid (row 1) and high (row 2). Mid is the median of the bootstrap means, while low and high
|
||||
bounds are specified by self._confidence_interval (which defaults to 0.95
|
||||
meaning it will return the 2.5th and 97.5th percentiles for a 95%
|
||||
confidence interval on the mean).
|
||||
"""
|
||||
|
||||
# Matrix of (bootstrap sample, measure).
|
||||
sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
|
||||
for i in range(self._n_samples):
|
||||
sample_idx = np.random.choice(
|
||||
np.arange(matrix.shape[0]), size=matrix.shape[0])
|
||||
sample = matrix[sample_idx, :]
|
||||
sample_mean[i, :] = np.mean(sample, axis=0)
|
||||
|
||||
# Take percentiles on the estimate of the mean using bootstrap samples.
|
||||
# Final result is a (bounds, measure) matrix.
|
||||
percentile_delta = (1 - self._confidence_interval) / 2
|
||||
q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
|
||||
return np.percentile(sample_mean, q, axis=0)
|
||||
|
||||
|
||||
def func_fmeasure(precision, recall):
|
||||
"""Computes f-measure given precision and recall values."""
|
||||
|
||||
if precision + recall > 0:
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
else:
|
||||
return 0.0
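# Illustrative sketch, not part of the original file: BootstrapAggregator stacks the
# per-example Score tuples, resamples them with replacement, and reports the 2.5th,
# 50th and 97.5th percentiles of the resampled means as (low, mid, high). Runs only
# when this module is executed directly.
if __name__ == "__main__":
  agg = BootstrapAggregator(confidence_interval=0.95, n_samples=200)
  agg.add_scores({"rouge1": Score(precision=1.0, recall=0.5, fmeasure=func_fmeasure(1.0, 0.5))})
  agg.add_scores({"rouge1": Score(precision=0.5, recall=0.25, fmeasure=func_fmeasure(0.5, 0.25))})
  print(agg.aggregate()["rouge1"].mid)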
@@ -0,0 +1,92 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
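# Illustrative sketch, not part of the original script: a worked example of the token
# F1 above. The helper is never called here; it only documents the computation.
def _f1_score_example():
    # normalize_answer drops the articles, leaving ['cat', 'sat'] vs. ['cat', 'sat', 'down'];
    # 2 shared tokens give precision 2/2, recall 2/3 and F1 = 0.8.
    return f1_score("a cat sat", "the cat sat down")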
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
|
||||
def evaluate(dataset, predictions):
|
||||
f1 = exact_match = total = 0
|
||||
for article in dataset:
|
||||
for paragraph in article["paragraphs"]:
|
||||
for qa in paragraph["qas"]:
|
||||
total += 1
|
||||
if qa["id"] not in predictions:
|
||||
message = "Unanswered question " + qa["id"] + " will receive score 0."
|
||||
print(message, file=sys.stderr)
|
||||
continue
|
||||
ground_truths = list(map(lambda x: x["text"], qa["answers"]))
|
||||
prediction = predictions[qa["id"]]
|
||||
exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
exact_match = 100.0 * exact_match / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
return {"exact_match": exact_match, "f1": f1}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
expected_version = "1.1"
|
||||
parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
|
||||
parser.add_argument("dataset_file", help="Dataset file")
|
||||
parser.add_argument("prediction_file", help="Prediction File")
|
||||
args = parser.parse_args()
|
||||
with open(args.dataset_file) as dataset_file:
|
||||
dataset_json = json.load(dataset_file)
|
||||
if dataset_json["version"] != expected_version:
|
||||
print(
|
||||
"Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
|
||||
file=sys.stderr,
|
||||
)
|
||||
dataset = dataset_json["data"]
|
||||
with open(args.prediction_file) as prediction_file:
|
||||
predictions = json.load(prediction_file)
|
||||
print(json.dumps(evaluate(dataset, predictions)))
@@ -0,0 +1,109 @@
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" SQuAD metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
from .evaluate import evaluate
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@inproceedings{Rajpurkar2016SQuAD10,
|
||||
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
|
||||
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
|
||||
booktitle={EMNLP},
|
||||
year={2016}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """
|
||||
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
|
||||
|
||||
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
|
||||
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
|
||||
from the corresponding reading passage, or the question might be unanswerable.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes SQuAD scores (F1 and EM).
|
||||
Args:
|
||||
predictions: List of question-answers dictionaries with the following key-values:
|
||||
- 'id': id of the question-answer pair as given in the references (see below)
|
||||
- 'prediction_text': the text of the answer
|
||||
references: List of question-answers dictionaries with the following key-values:
|
||||
- 'id': id of the question-answer pair (see above),
|
||||
- 'answers': a Dict in the SQuAD dataset format
|
||||
{
|
||||
'text': list of possible texts for the answer, as a list of strings
|
||||
'answer_start': list of start positions for the answer, as a list of ints
|
||||
}
|
||||
Note that answer_start values are not taken into account to compute the metric.
|
||||
Returns:
|
||||
    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
|
||||
'f1': The F-score of predicted tokens versus the gold answer
|
||||
Examples:
|
||||
|
||||
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
|
||||
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
|
||||
>>> squad_metric = datasets.load_metric("squad")
|
||||
>>> results = squad_metric.compute(predictions=predictions, references=references)
|
||||
>>> print(results)
|
||||
{'exact_match': 100.0, 'f1': 100.0}
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class Squad(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
|
||||
"references": {
|
||||
"id": datasets.Value("string"),
|
||||
"answers": datasets.features.Sequence(
|
||||
{
|
||||
"text": datasets.Value("string"),
|
||||
"answer_start": datasets.Value("int32"),
|
||||
}
|
||||
),
|
||||
},
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
|
||||
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references):
|
||||
pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
|
||||
dataset = [
|
||||
{
|
||||
"paragraphs": [
|
||||
{
|
||||
"qas": [
|
||||
{
|
||||
"answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
|
||||
"id": ref["id"],
|
||||
}
|
||||
for ref in references
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
score = evaluate(dataset=dataset, predictions=pred_dict)
|
||||
return score
@@ -0,0 +1,330 @@
"""Official evaluation script for SQuAD version 2.0.
|
||||
|
||||
In addition to basic functionality, we also compute additional statistics and
|
||||
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||
This file is expected to map question ID's to the model's predicted probability
|
||||
that a question is unanswerable.
|
||||
"""
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
OPTS = None
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.")
|
||||
parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
|
||||
parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.")
|
||||
parser.add_argument(
|
||||
"--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--na-prob-thresh",
|
||||
"-t",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help='Predict "" if no-answer probability exceeds this (default = 1.0).',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory."
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
if len(sys.argv) == 1:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_qid_to_has_ans(dataset):
|
||||
qid_to_has_ans = {}
|
||||
for article in dataset:
|
||||
for p in article["paragraphs"]:
|
||||
for qa in p["qas"]:
|
||||
qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"])
|
||||
return qid_to_has_ans
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
||||
return re.sub(regex, " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def get_tokens(s):
|
||||
if not s:
|
||||
return []
|
||||
return normalize_answer(s).split()
|
||||
|
||||
|
||||
def compute_exact(a_gold, a_pred):
|
||||
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
||||
|
||||
|
||||
def compute_f1(a_gold, a_pred):
|
||||
gold_toks = get_tokens(a_gold)
|
||||
pred_toks = get_tokens(a_pred)
|
||||
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
||||
num_same = sum(common.values())
|
||||
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
||||
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
||||
return int(gold_toks == pred_toks)
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(pred_toks)
|
||||
recall = 1.0 * num_same / len(gold_toks)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def get_raw_scores(dataset, preds):
|
||||
exact_scores = {}
|
||||
f1_scores = {}
|
||||
for article in dataset:
|
||||
for p in article["paragraphs"]:
|
||||
for qa in p["qas"]:
|
||||
qid = qa["id"]
|
||||
if qid not in preds:
|
||||
print(f"Missing prediction for {qid}")
|
||||
continue
|
||||
a_pred = preds[qid]
|
||||
|
||||
gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
|
||||
if not gold_answers:
|
||||
# For unanswerable questions, only correct answer is empty string
|
||||
gold_answers = [""]
|
||||
if a_pred != "":
|
||||
exact_scores[qid] = 0
|
||||
f1_scores[qid] = 0
|
||||
else:
|
||||
exact_scores[qid] = 1
|
||||
f1_scores[qid] = 1
|
||||
else:
|
||||
# Take max over all gold answers
|
||||
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
|
||||
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
|
||||
return exact_scores, f1_scores
|
||||
|
||||
|
||||
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
||||
new_scores = {}
|
||||
for qid, s in scores.items():
|
||||
pred_na = na_probs[qid] > na_prob_thresh
|
||||
if pred_na:
|
||||
new_scores[qid] = float(not qid_to_has_ans[qid])
|
||||
else:
|
||||
new_scores[qid] = s
|
||||
return new_scores
|
||||
|
||||
|
||||
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||
if not qid_list:
|
||||
total = len(exact_scores)
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("exact", 100.0 * sum(exact_scores.values()) / total),
|
||||
("f1", 100.0 * sum(f1_scores.values()) / total),
|
||||
("total", total),
|
||||
]
|
||||
)
|
||||
else:
|
||||
total = len(qid_list)
|
||||
return collections.OrderedDict(
|
||||
[
|
||||
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||
("total", total),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def merge_eval(main_eval, new_eval, prefix):
|
||||
for k in new_eval:
|
||||
main_eval[f"{prefix}_{k}"] = new_eval[k]
|
||||
|
||||
|
||||
def plot_pr_curve(precisions, recalls, out_image, title):
|
||||
plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
|
||||
plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
|
||||
plt.xlabel("Recall")
|
||||
plt.ylabel("Precision")
|
||||
plt.xlim([0.0, 1.05])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.title(title)
|
||||
plt.savefig(out_image)
|
||||
plt.clf()
|
||||
|
||||
|
||||
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
true_pos = 0.0
|
||||
cur_p = 1.0
|
||||
cur_r = 0.0
|
||||
precisions = [1.0]
|
||||
recalls = [0.0]
|
||||
avg_prec = 0.0
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid_to_has_ans[qid]:
|
||||
true_pos += scores[qid]
|
||||
cur_p = true_pos / float(i + 1)
|
||||
cur_r = true_pos / float(num_true_pos)
|
||||
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
|
||||
# i.e., if we can put a threshold after this point
|
||||
avg_prec += cur_p * (cur_r - recalls[-1])
|
||||
precisions.append(cur_p)
|
||||
recalls.append(cur_r)
|
||||
if out_image:
|
||||
plot_pr_curve(precisions, recalls, out_image, title)
|
||||
return {"ap": 100.0 * avg_prec}
|
||||
|
||||
|
||||
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
|
||||
if out_image_dir and not os.path.exists(out_image_dir):
|
||||
os.makedirs(out_image_dir)
|
||||
num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
|
||||
if num_true_pos == 0:
|
||||
return
|
||||
pr_exact = make_precision_recall_eval(
|
||||
exact_raw,
|
||||
na_probs,
|
||||
num_true_pos,
|
||||
qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, "pr_exact.png"),
|
||||
title="Precision-Recall curve for Exact Match score",
|
||||
)
|
||||
pr_f1 = make_precision_recall_eval(
|
||||
f1_raw,
|
||||
na_probs,
|
||||
num_true_pos,
|
||||
qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, "pr_f1.png"),
|
||||
title="Precision-Recall curve for F1 score",
|
||||
)
|
||||
oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
|
||||
pr_oracle = make_precision_recall_eval(
|
||||
oracle_scores,
|
||||
na_probs,
|
||||
num_true_pos,
|
||||
qid_to_has_ans,
|
||||
out_image=os.path.join(out_image_dir, "pr_oracle.png"),
|
||||
title="Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
|
||||
)
|
||||
merge_eval(main_eval, pr_exact, "pr_exact")
|
||||
merge_eval(main_eval, pr_f1, "pr_f1")
|
||||
merge_eval(main_eval, pr_oracle, "pr_oracle")
|
||||
|
||||
|
||||
def histogram_na_prob(na_probs, qid_list, image_dir, name):
|
||||
if not qid_list:
|
||||
return
|
||||
x = [na_probs[k] for k in qid_list]
|
||||
weights = np.ones_like(x) / float(len(x))
|
||||
plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
|
||||
plt.xlabel("Model probability of no-answer")
|
||||
plt.ylabel("Proportion of dataset")
|
||||
plt.title(f"Histogram of no-answer probability: {name}")
|
||||
plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
|
||||
plt.clf()
|
||||
|
||||
|
||||
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||
cur_score = num_no_ans
|
||||
best_score = cur_score
|
||||
best_thresh = 0.0
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid not in scores:
|
||||
continue
|
||||
if qid_to_has_ans[qid]:
|
||||
diff = scores[qid]
|
||||
else:
|
||||
if preds[qid]:
|
||||
diff = -1
|
||||
else:
|
||||
diff = 0
|
||||
cur_score += diff
|
||||
if cur_score > best_score:
|
||||
best_score = cur_score
|
||||
best_thresh = na_probs[qid]
|
||||
return 100.0 * best_score / len(scores), best_thresh
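# Note, not part of the original script: find_best_thresh sweeps candidate thresholds in
# order of increasing no-answer probability. It starts from the score of predicting "" for
# every question (num_no_ans correct); moving a question to the "answered" side adds its
# EM/F1 if it has a gold answer, or subtracts 1 if a non-empty answer was predicted for an
# unanswerable question. The running maximum gives best_score and best_thresh.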
|
||||
|
||||
|
||||
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval["best_exact"] = best_exact
|
||||
main_eval["best_exact_thresh"] = exact_thresh
|
||||
main_eval["best_f1"] = best_f1
|
||||
main_eval["best_f1_thresh"] = f1_thresh
|
||||
|
||||
|
||||
def main():
|
||||
with open(OPTS.data_file) as f:
|
||||
dataset_json = json.load(f)
|
||||
dataset = dataset_json["data"]
|
||||
with open(OPTS.pred_file) as f:
|
||||
preds = json.load(f)
|
||||
if OPTS.na_prob_file:
|
||||
with open(OPTS.na_prob_file) as f:
|
||||
na_probs = json.load(f)
|
||||
else:
|
||||
na_probs = {k: 0.0 for k in preds}
|
||||
qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
exact_raw, f1_raw = get_raw_scores(dataset, preds)
|
||||
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
|
||||
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
|
||||
out_eval = make_eval_dict(exact_thresh, f1_thresh)
|
||||
if has_ans_qids:
|
||||
has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
|
||||
merge_eval(out_eval, has_ans_eval, "HasAns")
|
||||
if no_ans_qids:
|
||||
no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
|
||||
merge_eval(out_eval, no_ans_eval, "NoAns")
|
||||
if OPTS.na_prob_file:
|
||||
find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
|
||||
if OPTS.na_prob_file and OPTS.out_image_dir:
|
||||
run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
|
||||
histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
|
||||
histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")
|
||||
if OPTS.out_file:
|
||||
with open(OPTS.out_file, "w") as f:
|
||||
json.dump(out_eval, f)
|
||||
else:
|
||||
print(json.dumps(out_eval, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
OPTS = parse_args()
|
||||
if OPTS.out_image_dir:
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
main()
@@ -0,0 +1,139 @@
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Datasets Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" SQuAD v2 metric. """
|
||||
|
||||
import datasets
|
||||
|
||||
from .evaluate import (
|
||||
apply_no_ans_threshold,
|
||||
find_all_best_thresh,
|
||||
get_raw_scores,
|
||||
make_eval_dict,
|
||||
make_qid_to_has_ans,
|
||||
merge_eval,
|
||||
)
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@inproceedings{Rajpurkar2016SQuAD10,
|
||||
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
|
||||
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
|
||||
booktitle={EMNLP},
|
||||
year={2016}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """
|
||||
This metric wraps the official scoring script for version 2 of the Stanford Question
|
||||
Answering Dataset (SQuAD).
|
||||
|
||||
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
|
||||
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
|
||||
from the corresponding reading passage, or the question might be unanswerable.
|
||||
|
||||
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions
|
||||
written adversarially by crowdworkers to look similar to answerable ones.
|
||||
To do well on SQuAD2.0, systems must not only answer questions when possible, but also
|
||||
determine when no answer is supported by the paragraph and abstain from answering.
|
||||
"""
|
||||
|
||||
_KWARGS_DESCRIPTION = """
|
||||
Computes SQuAD v2 scores (F1 and EM).
|
||||
Args:
|
||||
    predictions: List of triples for question-answers to score, with the following elements:
|
||||
- the question-answer 'id' field as given in the references (see below)
|
||||
- the text of the answer
|
||||
- the probability that the question has no answer
|
||||
references: List of question-answers dictionaries with the following key-values:
|
||||
- 'id': id of the question-answer pair (see above),
|
||||
- 'answers': a list of Dict {'text': text of the answer as a string}
|
||||
no_answer_threshold: float
|
||||
Probability threshold to decide that a question has no answer.
|
||||
Returns:
|
||||
    'exact': Exact match (the normalized answer exactly matches the gold answer)
|
||||
'f1': The F-score of predicted tokens versus the gold answer
|
||||
    'total': Number of scores considered
|
||||
    'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer)
|
||||
'HasAns_f1': The F-score of predicted tokens versus the gold answer
|
||||
    'HasAns_total': Number of scores considered
|
||||
    'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer)
|
||||
'NoAns_f1': The F-score of predicted tokens versus the gold answer
|
||||
    'NoAns_total': Number of scores considered
|
||||
'best_exact': Best exact match (with varying threshold)
|
||||
'best_exact_thresh': No-answer probability threshold associated to the best exact match
|
||||
'best_f1': Best F1 (with varying threshold)
|
||||
'best_f1_thresh': No-answer probability threshold associated to the best F1
|
||||
Examples:
|
||||
|
||||
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}]
|
||||
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
|
||||
>>> squad_v2_metric = datasets.load_metric("squad_v2")
|
||||
>>> results = squad_v2_metric.compute(predictions=predictions, references=references)
|
||||
>>> print(results)
|
||||
{'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0}
|
||||
"""
|
||||
|
||||
|
||||
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
||||
class SquadV2Local(datasets.Metric):
|
||||
def _info(self):
|
||||
return datasets.MetricInfo(
|
||||
description=_DESCRIPTION,
|
||||
citation=_CITATION,
|
||||
inputs_description=_KWARGS_DESCRIPTION,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"predictions": {
|
||||
"id": datasets.Value("string"),
|
||||
"prediction_text": datasets.Value("string"),
|
||||
"no_answer_probability": datasets.Value("float32"),
|
||||
},
|
||||
"references": {
|
||||
"id": datasets.Value("string"),
|
||||
"answers": datasets.features.Sequence(
|
||||
{"text": datasets.Value("string"), "answer_start": datasets.Value("int32")}
|
||||
),
|
||||
},
|
||||
}
|
||||
),
|
||||
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
|
||||
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
|
||||
)
|
||||
|
||||
def _compute(self, predictions, references, no_answer_threshold=1.0):
|
||||
# no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions)
|
||||
dataset = [{"paragraphs": [{"qas": references}]}]
|
||||
predictions = dict((p["id"], p["prediction_text"]) for p in predictions)
|
||||
|
||||
# qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
|
||||
# has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
# no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
|
||||
exact_raw, f1_raw = get_raw_scores(dataset, predictions)
|
||||
out_eval = make_eval_dict(exact_raw, f1_raw)
|
||||
#
|
||||
#
|
||||
# exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
|
||||
# f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold)
|
||||
# out_eval = make_eval_dict(exact_thresh, f1_thresh)
|
||||
#
|
||||
# if has_ans_qids:
|
||||
# has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
|
||||
# merge_eval(out_eval, has_ans_eval, "HasAns")
|
||||
# if no_ans_qids:
|
||||
# no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
|
||||
# merge_eval(out_eval, no_ans_eval, "NoAns")
|
||||
# find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans)
|
||||
return dict(out_eval)
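# Note, not part of the original file: with the no-answer thresholding above commented out,
# this local variant only reports the aggregate 'exact', 'f1' and 'total' fields, and the
# 'no_answer_probability' passed in the predictions is ignored.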
@@ -0,0 +1,192 @@
import sys
|
||||
sys.path.append('../data_process')
|
||||
from typing import List, Optional, Tuple
|
||||
from QAInput import SimpleQAInput as QAInput
|
||||
#from QAInput import QAInputNoPrompt as QAInput
|
||||
|
||||
def preprocess_sqaud_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_sqaud_abstractive_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_boolq_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
question_column, context_column, answer_column = 'question', 'passage', 'answer'
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
# print(inputs,targets)
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_boolq_batch_pretrain(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
# print(inputs,targets)
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_narrativeqa_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
contexts = [exp['summary']['text'] for exp in examples['document']]
|
||||
questions = [exp['text'] for exp in examples['question']]
|
||||
answers = [ans[0]['text'] for ans in examples['answers']]
|
||||
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_narrativeqa_batch_pretrain(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_drop_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
contexts = examples['passage']
|
||||
questions = examples['question']
|
||||
answers = examples['answers']
|
||||
inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_race_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
contexts = examples['article']
|
||||
questions = examples['question']
|
||||
all_options = examples['options']
|
||||
answers = examples['answer']
|
||||
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options]
|
||||
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
|
||||
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 }
|
||||
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_newsqa_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)]
|
||||
# inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
|
||||
def preprocess_ropes_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples[question_column]
|
||||
backgrounds = examples["background"]
|
||||
situations = examples["situation"]
|
||||
answers = examples[answer_column]
|
||||
|
||||
inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)]
|
||||
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_openbookqa_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
questions = examples['question_stem']
|
||||
all_options = examples['choices']
|
||||
answers = examples['answerKey']
|
||||
options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options]
|
||||
inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)]
|
||||
ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
|
||||
targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_social_iqa_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
contexts = examples['article']
|
||||
questions = examples['question']
|
||||
all_options = examples['options']
|
||||
answers = examples['answer']
|
||||
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
|
||||
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
|
||||
ans_map = {'A': 0, 'B': 1, 'C': 2,}
|
||||
targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)]
|
||||
return inputs, targets
|
||||
|
||||
def preprocess_dream_batch(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column: str,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
contexts = [" ".join(dialogue) for dialogue in examples['dialogue']]
|
||||
questions = examples['question']
|
||||
all_options = examples['choice']
|
||||
answers = examples['answer']
|
||||
answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)]
|
||||
options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options]
|
||||
inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)]
|
||||
targets = answers
|
||||
return inputs, targets
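# Illustrative sketch, not part of the original file: these preprocess_* helpers are meant
# to be applied batch-wise, e.g. via datasets.Dataset.map. The column names 'question',
# 'context' and 'answers' below are the SQuAD ones and are assumptions for this example.
def _example_map_squad(raw_dataset):
    def _to_seq2seq(examples):
        inputs, targets = preprocess_sqaud_batch(examples, "question", "context", "answers")
        return {"source_text": inputs, "target_text": targets}
    return raw_dataset.map(_to_seq2seq, batched=True)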
@@ -0,0 +1,248 @@
import itertools
|
||||
import json
|
||||
import os
|
||||
import csv
|
||||
import errno
|
||||
import random
|
||||
from random import shuffle
|
||||
from typing import List
|
||||
#import spacy
|
||||
from tqdm import tqdm
|
||||
import codecs
|
||||
#import nltk
|
||||
import glob
|
||||
import xml.etree.ElementTree as ET
|
||||
from datasets import load_dataset
|
||||
#import statistic
|
||||
from QAInput import StructuralQAInput, SimpleQAInput
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
M2M100Tokenizer,
|
||||
MBart50Tokenizer,
|
||||
EvalPrediction,
|
||||
MBart50TokenizerFast,
|
||||
MBartTokenizer,
|
||||
MBartTokenizerFast,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./t5-base")
|
||||
#nltk.download("stopwords")
|
||||
#from nltk.corpus import stopwords
|
||||
task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8}
|
||||
task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0}
|
||||
#STOPWORDS = stopwords.words("english")
|
||||
#STOPWORDS = [stopword + " " for stopword in STOPWORDS]
|
||||
#nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def chuli_example(examples,fname):
|
||||
cur_dataset = fname
|
||||
|
||||
answers_start = []
|
||||
if fname in ["none"]:
|
||||
addition_answers = open(fname+"_answers.json",'r')
|
||||
faddition = json.load(addition_answers)
|
||||
for item in faddition:
|
||||
answers_start.append({"text":item,"answer_start":[0]*len(item)})
|
||||
else:
|
||||
for item in examples["answer"]:
|
||||
answers_start.append({"text":item,"answer_start":[0]*len(item)})
|
||||
print(answers_start[:10])
|
||||
examples = examples.add_column("answers",answers_start)
|
||||
return examples
|
||||
|
||||
|
||||
def preprocess_plain(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column:str):
|
||||
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
|
||||
inputs = [SimpleQAInput.qg_input(question, context) for question,context in zip(questions,contexts)]
|
||||
answers = [_[0] for _ in answers]
|
||||
return inputs,answers
|
||||
|
||||
|
||||
def preprocess_proqa(
|
||||
examples,
|
||||
question_column: str,
|
||||
context_column: str,
|
||||
answer_column:str):
|
||||
|
||||
questions = examples[question_column]
|
||||
contexts = examples[context_column]
|
||||
answers = examples[answer_column]
|
||||
|
||||
inputs = [StructuralQAInput.qg_input(question, context) for question,context in zip(questions,contexts)]
|
||||
answers = [_[0] for _ in answers]
|
||||
return inputs,answers
|
||||
|
||||
|
||||
|
||||
|
||||
def preprocess_function(examples):
|
||||
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "question","context","answer")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)]
|
||||
format_id = task2format[dataset_name]
|
||||
format_prompt_id_start = 300
|
||||
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10,
|
||||
format_prompt_id_start + (format_id + 1) * 10)]
|
||||
task_id = task2id[dataset_name]
|
||||
task_prompt_id_start = 0
|
||||
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20,
|
||||
task_prompt_id_start + (task_id + 1) * 20)]
|
||||
|
||||
|
||||
domain_prompt_id_start = 20*30
|
||||
domain_prompt_number = 20
|
||||
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
|
||||
domain_prompt_id_start + 20)]*5
|
||||
input_ids = copy.deepcopy(
|
||||
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
|
||||
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
|
||||
model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in
|
||||
model_inputs['attention_mask']]
|
||||
return model_inputs
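# Note, not part of the original file: the prepended prompt prefix above is 10 general ids
# + 10 format ids + 20 task ids + 100 domain ids (20 ids repeated 5 times) = 140 positions,
# which is why the attention mask is extended with [1] * 140.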
|
||||
|
||||
def preprocess_function_valid(examples):
|
||||
preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name)
|
||||
inputs, targets = preprocess_fn(examples, "question","context","answer")
|
||||
model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(targets, max_length=128, padding=False, truncation=True)
|
||||
model_inputs["example_id"] = []
|
||||
|
||||
for i in range(len(model_inputs["input_ids"])):
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = i #sample_mapping[i]
|
||||
model_inputs["example_id"].append(examples["id"][sample_index])
|
||||
|
||||
|
||||
|
||||
gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)]
|
||||
format_id = task2format[dataset_name]
|
||||
format_prompt_id_start = 300
|
||||
format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10,
|
||||
format_prompt_id_start + (format_id + 1) * 10)]
|
||||
task_id = task2id[dataset_name]
|
||||
task_prompt_id_start = 0
|
||||
task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20,
|
||||
task_prompt_id_start + (task_id + 1) * 20)]
|
||||
|
||||
|
||||
domain_prompt_id_start = 20*30
|
||||
domain_prompt_number = 20
|
||||
domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start,
|
||||
domain_prompt_id_start + 20)]*5
|
||||
input_ids = copy.deepcopy(
|
||||
[gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']])
|
||||
model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]
|
||||
model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in
|
||||
model_inputs['attention_mask']]
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
|
||||
return model_inputs
|
||||
|
||||
def add_id(example,index):
|
||||
example.update({'id':index})
|
||||
return example
|
||||
|
||||
|
||||
|
||||
def prep(raw_ds,fname):
|
||||
if True:
|
||||
column_names = raw_ds.column_names
|
||||
global dataset_name
|
||||
dataset_name = fname
|
||||
train_dataset = raw_ds.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=4,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
desc="Running tokenizer on train dataset",
|
||||
)
|
||||
train_dataset = train_dataset.add_column("id",range(len(train_dataset)))
|
||||
train_dataset.save_to_disk("./ours/{}-train.hf".format(fname))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import copy
|
||||
def prep_valid(raw_ds,fname):
|
||||
global dataset_name
|
||||
dataset_name = fname
|
||||
eval_examples = copy.deepcopy(raw_ds)
|
||||
eval_examples = chuli_example(eval_examples,fname)
|
||||
column_names = raw_ds.column_names
|
||||
if 'id' not in eval_examples.features.keys():
|
||||
eval_examples = eval_examples.map(add_id,with_indices=True)
|
||||
|
||||
|
||||
if True:
|
||||
eval_dataset = raw_ds.map(add_id,with_indices=True)
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function_valid,
|
||||
batched=True,
|
||||
num_proc=4,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=True,
|
||||
desc="Running tokenizer on validation dataset",
|
||||
)
|
||||
eval_dataset.save_to_disk("./ours/{}-eval.hf".format(fname))
|
||||
eval_examples.save_to_disk("./ours/{}-evalex.hf".format(fname))
|
||||
|
||||
|
||||
from collections import Counter
|
||||
import json
|
||||
|
||||
|
||||
def main():
|
||||
for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]:
|
||||
print("current_dataset:",item)
|
||||
print("Loading......")
|
||||
data_files = {}
|
||||
|
||||
data_files["validation"] = item+"-eval.jsonl"
|
||||
test_dataset = load_dataset("json", data_files=data_files)["validation"]
|
||||
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
data_files = {}
|
||||
data_files["train"] = item+"-train.jsonl"
|
||||
train_dataset = load_dataset("json", data_files=data_files)["train"]
|
||||
|
||||
|
||||
|
||||
print("Loaded.")
|
||||
if not item in ["cnn_dailymail","multinli.in.out","zre"]:
|
||||
prep(train_dataset,item)
|
||||
print("preped")
|
||||
prep_valid(test_dataset,item)
|
||||
print("valid prepred")
|
||||
|
||||
|
||||
|
||||
main()
@@ -0,0 +1,477 @@
import collections
|
||||
import string
|
||||
import re
|
||||
import numpy as np
|
||||
import json
|
||||
AGG_OPS = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
|
||||
COND_OPS = ['=', '>', '<', 'OP']
|
||||
COND_OPS = ['=', '>', '<']
|
||||
#from pyrouge import Rouge155
|
||||
#from multiprocessing import Pool, cpu_count
|
||||
#from contextlib import closing
|
||||
from datasets import load_metric
|
||||
|


def compute_rouge_scores(summs, refs, splitchar='.', options=None, parallel=True):
    # Requires pyrouge (Rouge), split_sentences, multiprocessing and contextlib,
    # whose imports are commented out above; the ROUGE path in compute_metrics
    # uses datasets.load_metric instead, so this helper is effectively unused.
    assert len(summs) == len(refs)
    options = [
        '-a',       # evaluate all systems
        '-c', 95,   # confidence interval
        '-m',       # use Porter stemmer
        '-n', 2,    # max-ngram
        '-w', 1.3,  # weight (weighting factor for WLCS)
    ]
    rr = Rouge(options=options)
    rouge_args = []
    for summ, ref in zip(summs, refs):
        letter = "A"
        ref_dict = {}
        for r in ref:
            ref_dict[letter] = [x for x in split_sentences(r, splitchar) if len(x) > 0]
            letter = chr(ord(letter) + 1)
        s = [x for x in split_sentences(summ, splitchar) if len(x) > 0]
        rouge_args.append((s, ref_dict))
    if parallel:
        with closing(Pool(cpu_count() // 2)) as pool:
            rouge_scores = pool.starmap(rr.score_summary, rouge_args)
    else:
        rouge_scores = []
        for s, a in rouge_args:
            rouge_scores.append(rr.score_summary(s, a))
    return rouge_scores


def computeROUGE(greedy, answer):
    rouges = compute_rouge_scores(greedy, answer)
    if len(rouges) > 0:
        avg_rouges = {}
        for key in rouges[0].keys():
            avg_rouges[key] = sum(
                [r.get(key, 0.0) for r in rouges]) / len(rouges) * 100
    else:
        avg_rouges = None
    return avg_rouges


def lcsstr(string1, string2):
    # Length of the longest run of characters of string1 that matches string2
    # at some fixed alignment offset (a cheap common-substring heuristic).
    answer = 0
    len1, len2 = len(string1), len(string2)
    for i in range(len1):
        match = 0
        for j in range(len2):
            if (i + j < len1 and string1[i + j] == string2[j]):
                match += 1
                if match > answer:
                    answer = match
            else:
                match = 0
    return answer


def lcs(X, Y):
    # Longest-common-subsequence ratio, with the substring match and the shorter
    # length used as small tie-breakers.
    m = len(X)
    n = len(Y)
    L = [[None] * (n + 1) for i in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])
    return L[m][n] / max(m, n) + lcsstr(X, Y) / 1e4 - min(m, n) / 1e8


def normalize_text(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def to_delta_state(line):
    delta_state = {'inform': {}, 'request': {}}
    try:
        if line.lower() == 'none' or line.strip() == '' or line.strip() == ';':
            return delta_state
        inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')]
        inform_pairs = {}
        for i in inform:
            try:
                k, v = i.split(':')
                inform_pairs[k.strip()] = v.strip()
            except Exception:
                pass
        delta_state = {'inform': inform_pairs, 'request': request}
    except Exception:
        pass
    finally:
        return delta_state


def update_state(state, delta):
    for act, slot in delta.items():
        state[act] = slot
    return state


def dict_cmp(d1, d2):
    def cmp(a, b):
        for k1, v1 in a.items():
            if k1 not in b:
                return False
            else:
                if v1 != b[k1]:
                    return False
        return True
    return cmp(d1, d2) and cmp(d2, d1)


def to_lf(s, table):
    # Heuristically parse a generated answer string into a WikiSQL-style logical
    # form {'sel', 'agg', 'conds'}, matching column names via the lcs() score.
    aggs = [y.lower() for y in AGG_OPS]
    agg_to_idx = {x: i for i, x in enumerate(aggs)}
    conditionals = [y.lower() for y in COND_OPS]
    headers_unsorted = [(y.lower(), i) for i, y in enumerate(table['header'])]
    headers = [(y.lower(), i) for i, y in enumerate(table['header'])]
    headers.sort(reverse=True, key=lambda x: len(x[0]))
    condition_s, conds = None, []
    if 'where' in s:
        s, condition_s = s.split('where', 1)
    s = ' '.join(s.split()[1:-2])
    s_no_agg = ' '.join(s.split()[1:])
    sel, agg = None, 0
    lcss, idxs = [], []
    for col, idx in headers:
        lcss.append(lcs(col, s))
        lcss.append(lcs(col, s_no_agg))
        idxs.append(idx)
    lcss = np.array(lcss)
    max_id = np.argmax(lcss)
    sel = idxs[max_id // 2]
    if max_id % 2 == 1:  # with agg
        agg = agg_to_idx[s.split()[0]]

    full_conditions = []
    if condition_s is not None:
        pattern = '|'.join(COND_OPS)
        split_conds_raw = re.split(pattern, condition_s)
        split_conds_raw = [conds.strip() for conds in split_conds_raw]
        split_conds = [split_conds_raw[0]]
        for i in range(1, len(split_conds_raw) - 1):
            split_conds.extend(re.split('and', split_conds_raw[i]))
        split_conds += [split_conds_raw[-1]]
        # Snap every column mention onto the closest real header.
        for i in range(0, len(split_conds), 2):
            cur_s = split_conds[i]
            lcss = []
            for col in headers:
                lcss.append(lcs(col[0], cur_s))
            max_id = np.argmax(np.array(lcss))
            split_conds[i] = headers[max_id][0]
        for i, m in enumerate(re.finditer(pattern, condition_s)):
            split_conds[2 * i] = split_conds[2 * i] + ' ' + m.group()
        split_conds = [' '.join(split_conds[2 * i:2 * i + 2]) for i in range(len(split_conds) // 2)]

        condition_s = ' and '.join(split_conds)
        condition_s = ' ' + condition_s + ' '
        for idx, col in enumerate(headers):
            condition_s = condition_s.replace(' ' + col[0] + ' ', ' Col{} '.format(col[1]))
        condition_s = condition_s.strip()

        for idx, col in enumerate(conditionals):
            new_s = []
            for t in condition_s.split():
                if t == col:
                    new_s.append('Cond{}'.format(idx))
                else:
                    new_s.append(t)
            condition_s = ' '.join(new_s)
        s = condition_s

        conds = re.split(r'(Col\d+ Cond\d+)', s)
        if len(conds) == 0:
            conds = [s]
        conds = [x for x in conds if len(x.strip()) > 0]
        full_conditions = []
        for i, x in enumerate(conds):
            if i % 2 == 0:
                x = x.split()
                col_num = int(x[0].replace('Col', ''))
                opp_num = int(x[1].replace('Cond', ''))
                full_conditions.append([col_num, opp_num])
            else:
                x = x.split()
                if x[-1] == 'and':
                    x = x[:-1]
                x = ' '.join(x)
                if 'Col' in x:
                    new_x = []
                    for t in x.split():
                        if 'Col' in t:
                            idx = int(t.replace('Col', ''))
                            t = headers_unsorted[idx][0]
                        new_x.append(t)
                    x = new_x
                    x = ' '.join(x)
                if 'Cond' in x:
                    new_x = []
                    for t in x.split():
                        if 'Cond' in t:
                            idx = int(t.replace('Cond', ''))
                            t = conditionals[idx]
                        new_x.append(t)
                    x = new_x
                    x = ' '.join(x)
                full_conditions[-1].append(x.replace(' ', ''))
    logical_form = {'sel': sel, 'conds': full_conditions, 'agg': agg}
    return logical_form


def computeLFEM(greedy, answer):
    count = 0
    correct = 0
    text_answers = []
    for idx, (g, ex) in enumerate(zip(greedy, answer)):
        count += 1
        text_answers.append([ex['answer'].lower()])
        try:
            gt = ex['sql']
            conds = gt['conds']
            lower_conds = []
            for c in conds:
                lc = c
                lc[2] = str(lc[2]).lower().replace(' ', '')
                lower_conds.append(lc)
            gt['conds'] = lower_conds
            lf = to_lf(g, ex['table'])
            correct += lf == gt
        except Exception:
            continue
    return (correct / count) * 100, text_answers


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    prediction_tokens = prediction.split()
    ground_truth_tokens = ground_truth.split()
    common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
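# Example (illustrative): f1_score("the cat sat", "cat sat down") shares the
# tokens {"cat", "sat"}, giving precision 2/3, recall 2/3, and token F1 ≈ 0.667.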


def exact_match(prediction, ground_truth):
    return prediction == ground_truth


def computeF1(outputs, targets):
    return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)]) / len(outputs) * 100


def computeEM(outputs, targets):
    outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)]
    return sum(outs) / len(outputs) * 100


def fix_buggy_characters(text):
    return re.sub("[{}^\\\\`\u2047<]", " ", text)


def replace_punctuation(text):
    return text.replace("\"", "").replace("'", "")


def score_string_similarity(str1, str2):
    if str1 == str2:
        return 3.0  # Better than perfect token match
    str1 = fix_buggy_characters(replace_punctuation(str1))
    str2 = fix_buggy_characters(replace_punctuation(str2))
    if str1 == str2:
        return 2.0
    if " " in str1 or " " in str2:
        str1_split = str1.split(" ")
        str2_split = str2.split(" ")
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        if str1 == str2:
            return 1.0
        else:
            return 0.0


def computeEMtri(outputs, targets):
    options = ["entailment", "neutral", "contradiction"]
    preds = []
    for o in outputs:
        scores = [score_string_similarity(opt, o) for opt in options]
        max_idx = np.argmax(scores)
        preds.append(options[max_idx])
    print(preds)
    print(targets)
    outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(preds, targets)]
    return sum(outs) / len(outputs) * 100


def score(answer, gold):
    if len(gold) > 0:
        gold = set.union(*[simplify(g) for g in gold])
    answer = simplify(answer)
    tp, tn, sys_pos, real_pos = 0, 0, 0, 0
    if answer == gold:
        if not ('unanswerable' in gold and len(gold) == 1):
            tp += 1
        else:
            tn += 1
    if not ('unanswerable' in answer and len(answer) == 1):
        sys_pos += 1
    if not ('unanswerable' in gold and len(gold) == 1):
        real_pos += 1
    return np.array([tp, tn, sys_pos, real_pos])


def simplify(answer):
    return set(''.join(c for c in t if c not in string.punctuation) for t in answer.strip().lower().split()) - {'the', 'a', 'an', 'and', ''}


def computeCF1(greedy, answer):
    scores = np.zeros(4)
    for g, a in zip(greedy, answer):
        scores += score(g, a)
    tp, tn, sys_pos, real_pos = scores.tolist()
    total = len(answer)
    if tp == 0:
        p = r = f = 0.0
    else:
        p = tp / float(sys_pos)
        r = tp / float(real_pos)
        f = 2 * p * r / (p + r)
    return f * 100, p * 100, r * 100


def computeDialogue(greedy, answer):
    examples = []
    for idx, (g, a) in enumerate(zip(greedy, answer)):
        examples.append((a[0], g, a[1], idx))
    # examples.sort()
    turn_request_positives = 0
    turn_goal_positives = 0
    joint_goal_positives = 0
    ldt = None
    for ex in examples:
        if ldt is None or ldt.split('_')[:-1] != ex[0].split('_')[:-1]:
            state, answer_state = {}, {}
            ldt = ex[0]
        delta_state = to_delta_state(ex[1])
        answer_delta_state = to_delta_state(ex[2])
        state = update_state(state, delta_state['inform'])
        answer_state = update_state(answer_state, answer_delta_state['inform'])
        if dict_cmp(state, answer_state):
            joint_goal_positives += 1
        if delta_state['request'] == answer_delta_state['request']:
            turn_request_positives += 1
        if dict_cmp(delta_state['inform'], answer_delta_state['inform']):
            turn_goal_positives += 1

    joint_goal_em = joint_goal_positives / len(examples) * 100
    turn_request_em = turn_request_positives / len(examples) * 100
    turn_goal_em = turn_goal_positives / len(examples) * 100
    answer = [(x[-1], x[-2]) for x in examples]
    # answer.sort()
    answer = [[x[1]] for x in answer]
    return joint_goal_em, turn_request_em, turn_goal_em, answer


def compute_metrics(data, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, tri=False):
    if rouge:
        metric_func = load_metric("metric/rouge_local/rouge_metric.py")
        metrics = metric_func.compute(predictions=data.predictions, references=data.label_ids)
        metric_keys = ["rougeL"]
        metric_values = metrics["rougeL"]
        metric_dict = collections.OrderedDict(list(zip(metric_keys, [metric_values])))
        return metric_dict

    greedy = data.predictions
    answer = data.label_ids
    greedy = [_["prediction_text"] for _ in greedy]
    answer = [_["answers"]["text"] for _ in answer]

    if dialogue:
        with open("./oursdeca/" + "woz.en_answers.json", 'r') as addition_answers:
            answer = json.load(addition_answers)
    if logical_form:
        with open("./oursdeca/" + "wikisql_answers.json", 'r') as addition_answers:
            answer = json.load(addition_answers)

    metric_keys = []
    metric_values = []
    if logical_form:
        lfem, answer = computeLFEM(greedy, answer)
        metric_keys += ['lfem']
        metric_values += [lfem]
    if tri:
        em = computeEMtri(greedy, answer)
    else:
        em = computeEM(greedy, answer)
    print(greedy[:20])
    print(answer[:20])
    metric_keys.append('em')
    metric_values.append(em)
    norm_greedy = [normalize_text(g) for g in greedy]
    norm_answer = [[normalize_text(a) for a in ans] for ans in answer]

    nf1 = computeF1(norm_greedy, norm_answer)
    nem = computeEM(norm_greedy, norm_answer)
    metric_keys.extend(['nf1', 'nem'])
    metric_values.extend([nf1, nem])
    if rouge:
        rouge = computeROUGE(greedy, answer)
        metric_keys += ['rouge1', 'rouge2', 'rougeL', 'avg_rouge']
        avg_rouge = (rouge['rouge_1_f_score'] + rouge['rouge_2_f_score'] + rouge['rouge_l_f_score']) / 3
        metric_values += [rouge['rouge_1_f_score'], rouge['rouge_2_f_score'], rouge['rouge_l_f_score'], avg_rouge]
    if corpus_f1:
        corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
        metric_keys += ['corpus_f1', 'precision', 'recall']
        metric_values += [corpus_f1, precision, recall]
    if dialogue:
        joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
        avg_dialogue = (joint_goal_em + request_em) / 2
        metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
        metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
    metric_dict = collections.OrderedDict(list(zip(metric_keys, metric_values)))
    return metric_dict
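

# Minimal usage sketch (illustrative only): the field layout below is inferred
# from compute_metrics above, and SimpleNamespace stands in for the trainer's
# EvalPrediction-style container.
if __name__ == "__main__":
    from types import SimpleNamespace
    _demo = SimpleNamespace(
        predictions=[{"prediction_text": "paris"}],
        label_ids=[{"answers": {"text": ["Paris", "paris"]}}],
    )
    print(compute_metrics(_demo))  # OrderedDict with 'em', 'nf1', 'nem'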
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -0,0 +1,61 @@
import torch.nn.functional as F
from torch import nn
import torch
import copy

# Adaptive decision boundary
# [code] https://github.com/thuiar/Adaptive-Decision-Boundary
# [paper] https://www.aaai.org/AAAI21Papers/AAAI-9723.ZhangH.pdf


def euclidean_metric(a, b):
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b) ** 2).sum(dim=2)
    return logits


def cosine_metric(a, b):
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = (a * b).sum(dim=2)
    return logits


class BoundaryLoss(nn.Module):

    def __init__(self, num_labels=10, feat_dim=2):
        super(BoundaryLoss, self).__init__()
        self.num_labels = num_labels
        self.feat_dim = feat_dim
        # One learnable boundary radius per class; note the size is hardcoded
        # to 20 rather than num_labels.
        self.delta = nn.Parameter(torch.ones(20).cuda())
        nn.init.normal_(self.delta)

    def forward(self, pooled_output, centroids_, labels, group=0):
        centroids = copy.deepcopy(centroids_.detach())
        logits = cosine_metric(pooled_output, centroids)
        probs, preds = F.softmax(logits.detach(), dim=1).max(dim=1)
        delta = F.softplus(self.delta)
        c = centroids[labels]
        c = c / torch.norm(c, p=2, dim=1, keepdim=True)
        d = delta[labels]
        x = pooled_output

        # Cosine distance to the class centroid; samples outside the learned
        # boundary d are pulled in, samples inside push the boundary tighter.
        euc_dis = 1 - ((c * x).sum(dim=1))
        pos_mask = (euc_dis > d).type(torch.cuda.FloatTensor)
        neg_mask = (euc_dis < d).type(torch.cuda.FloatTensor)

        pos_loss = (euc_dis - d) * pos_mask
        neg_loss = (d - euc_dis) * neg_mask
        loss = pos_loss.mean() + neg_loss.mean()

        return loss, delta
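

# Minimal usage sketch (illustrative; the shapes — 768-d pooled features and 20
# class centroids — are assumptions, and a CUDA device is required by the
# .cuda()/torch.cuda.FloatTensor calls above):
if __name__ == "__main__" and torch.cuda.is_available():
    crit = BoundaryLoss(num_labels=20, feat_dim=768)
    feats = F.normalize(torch.randn(8, 768), dim=1).cuda()
    centroids = torch.randn(20, 768).cuda()
    labels = torch.randint(0, 20, (8,)).cuda()
    loss, delta = crit(feats, centroids, labels)
    loss.backward()  # only delta receives gradients; centroids are detached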
@ -0,0 +1,56 @@
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
async-timeout==4.0.2
asynctest==0.13.0
attrs==21.4.0
certifi==2022.5.18.1
charset-normalizer==2.0.12
click==8.1.3
datasets==2.2.2
dill==0.3.4
filelock==3.7.0
frozenlist==1.3.0
fsspec==2022.5.0
future==0.18.2
fuzzywuzzy==0.18.0
huggingface-hub==0.7.0
idna==3.3
importlib-metadata==4.11.4
joblib==1.1.0
multidict==6.0.2
multiprocess==0.70.12.2
nlp==0.4.0
nltk==3.7
numpy==1.21.6
packaging==21.3
pandas==1.1.5
pyarrow==8.0.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.1
PyYAML==6.0
regex==2022.4.24
requests==2.27.1
responses==0.18.0
rouge-score==0.0.4
sacremoses==0.0.53
scikit-learn==1.0.2
scipy==1.7.3
six==1.16.0
sklearn==0.0
threadpoolctl==3.1.0
tokenizers==0.10.3
tqdm==4.64.0
transformers==4.12.5
typing_extensions==4.2.0
urllib3==1.26.9
wget==3.2
xxhash==3.0.0
yarl==1.7.2
zipp==3.8.0
omegaconf
jsonlines
spacy
@ -0,0 +1,55 @@
# Diana

The PyTorch implementation of the paper [Domain Incremental Lifelong Learning in an Open World](https://arxiv.org/abs/2305.06555) (ACL 2023).

## Requirements

```bash
cd DomainIncrementalLL
pip install -r r.txt
```

## QA tasks

### Data preparation

Download the datasets, unzip them, and place them in `/data_process`:

[QA Dataset](https://drive.google.com/file/d/1x8vvrdfwKCXpT_M4NA_eso8cX-LlMWsQ/view?usp=share_link)

Tokenize the plain-text datasets:

```bash
cd downstream
python create_l2p.py --model_name_or_path t5-base --dataset_name null --output_dir null
```

### Train and evaluate on 4 GPUs

```bash
bash ./diana.sh ./ll 8 48
```

### Ablations: no task prompt & no meta prompt

To evaluate the ablation without task prompts:

```bash
bash ./dianawotask.sh ./ll 8 48
```

To evaluate the ablation without meta prompts:

```bash
bash ./dianawometa.sh ./ll 8 48
```

## DecaNLP tasks

[Dataset](https://drive.google.com/file/d/1pzIoTzooaecsU4HXP_n4-_7Uuxv9i4A-/view?usp=share_link)

Extract plain texts from the raw files:

```bash
cd downstreamdeca
python extract.py
```

Tokenize the plain texts:

```bash
python toknize.py
```

Run the training and evaluation script:

```bash
bash ./diana.sh ./ll 8 48
```