From 99c2ba11953689ee22b8883ebe45e6a35df14e0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=87=BA=E8=9B=B0?=
Date: Thu, 18 May 2023 20:12:38 +0800
Subject: [PATCH] add: diana

---
 diana/data_process/QAInput.py | 69 +
 diana/data_process/QGInput.py | 88 +
 diana/data_process/tasks.py | 156 ++
 diana/downstream/create_l2p.py | 782 ++++++
 diana/downstream/dataset_processors.py | 192 ++
 diana/downstream/diana-uni.sh | 53 +
 diana/downstream/diana.py | 827 ++++++
 diana/downstream/diana.sh | 53 +
 diana/downstream/dianawometa.py | 798 ++++++
 diana/downstream/dianawometa.sh | 53 +
 diana/downstream/dianawotask.py | 827 ++++++
 diana/downstream/dianawotask.sh | 53 +
 diana/downstream/l2ptrainer.py | 556 ++++
 diana/downstream/metric/accuracy.py | 105 +
 diana/downstream/metric/bleu.py | 126 +
 diana/downstream/metric/bleu_local/bleu.py | 219 ++
 .../metric/rouge_local/rouge_metric.py | 133 +
 .../metric/rouge_local/rouge_score.py | 318 +++
 .../metric/rouge_local/rouge_tokenizers.py | 90 +
 .../downstream/metric/rouge_local/scoring.py | 158 ++
 .../metric/squad_v1_local/evaluate.py | 92 +
 .../metric/squad_v1_local/squad_v1_local.py | 109 +
 .../metric/squad_v2_local/evaluate.py | 330 +++
 .../metric/squad_v2_local/squad_v2_local.py | 139 +
 diana/downstream/simple_processors.py | 192 ++
 diana/downstreamdeca/QAInput.py | 21 +
 diana/downstreamdeca/dataset_processors.py | 192 ++
 diana/downstreamdeca/diana.py | 801 ++++++
 diana/downstreamdeca/diana.sh | 53 +
 diana/downstreamdeca/diana_log.txt | 6 +
 diana/downstreamdeca/dianawometa.py | 775 ++++++
 diana/downstreamdeca/dianawometa.sh | 53 +
 diana/downstreamdeca/dianawotask.py | 801 ++++++
 diana/downstreamdeca/dianawotask.sh | 53 +
 diana/downstreamdeca/extract.py | 169 ++
 diana/downstreamdeca/l2ptrainer.py | 557 ++++
 diana/downstreamdeca/metric/accuracy.py | 105 +
 diana/downstreamdeca/metric/bleu.py | 126 +
 .../downstreamdeca/metric/bleu_local/bleu.py | 219 ++
 .../metric/rouge_local/rouge_metric.py | 133 +
 .../metric/rouge_local/rouge_score.py | 318 +++
 .../metric/rouge_local/rouge_tokenizers.py | 90 +
 .../metric/rouge_local/scoring.py | 158 ++
 .../metric/squad_v1_local/evaluate.py | 92 +
 .../metric/squad_v1_local/squad_v1_local.py | 109 +
 .../metric/squad_v2_local/evaluate.py | 330 +++
 .../metric/squad_v2_local/squad_v2_local.py | 139 +
 diana/downstreamdeca/simple_processors.py | 192 ++
 diana/downstreamdeca/toknize.py | 248 ++
 diana/metrics.py | 477 ++++
 diana/models/metadeca.py | 2406 ++++++++++++++++
 diana/models/metadecanometa.py | 2386 ++++++++++++++++
 diana/models/metadecanotask.py | 2349 ++++++++++++++++
 diana/models/metat5.py | 2409 +++++++++++++++++
 diana/models/metat5nometa.py | 2390 ++++++++++++++++
 diana/models/metat5notask.py | 2351 ++++++++++++++++
 diana/models/utils.py | 61 +
 diana/r.txt | 56 +
 diana/readme.md | 55 +
 59 files changed, 27198 insertions(+)
 create mode 100755 diana/data_process/QAInput.py
 create mode 100755 diana/data_process/QGInput.py
 create mode 100755 diana/data_process/tasks.py
 create mode 100644 diana/downstream/create_l2p.py
 create mode 100755 diana/downstream/dataset_processors.py
 create mode 100644 diana/downstream/diana-uni.sh
 create mode 100644 diana/downstream/diana.py
 create mode 100644 diana/downstream/diana.sh
 create mode 100644 diana/downstream/dianawometa.py
 create mode 100644 diana/downstream/dianawometa.sh
 create mode 100644 diana/downstream/dianawotask.py
 create mode 100644 diana/downstream/dianawotask.sh
 create mode 100755 diana/downstream/l2ptrainer.py
 create mode 100644 diana/downstream/metric/accuracy.py
 create mode 100644
diana/downstream/metric/bleu.py create mode 100644 diana/downstream/metric/bleu_local/bleu.py create mode 100644 diana/downstream/metric/rouge_local/rouge_metric.py create mode 100644 diana/downstream/metric/rouge_local/rouge_score.py create mode 100644 diana/downstream/metric/rouge_local/rouge_tokenizers.py create mode 100644 diana/downstream/metric/rouge_local/scoring.py create mode 100644 diana/downstream/metric/squad_v1_local/evaluate.py create mode 100644 diana/downstream/metric/squad_v1_local/squad_v1_local.py create mode 100644 diana/downstream/metric/squad_v2_local/evaluate.py create mode 100644 diana/downstream/metric/squad_v2_local/squad_v2_local.py create mode 100755 diana/downstream/simple_processors.py create mode 100644 diana/downstreamdeca/QAInput.py create mode 100755 diana/downstreamdeca/dataset_processors.py create mode 100644 diana/downstreamdeca/diana.py create mode 100644 diana/downstreamdeca/diana.sh create mode 100644 diana/downstreamdeca/diana_log.txt create mode 100644 diana/downstreamdeca/dianawometa.py create mode 100644 diana/downstreamdeca/dianawometa.sh create mode 100644 diana/downstreamdeca/dianawotask.py create mode 100644 diana/downstreamdeca/dianawotask.sh create mode 100644 diana/downstreamdeca/extract.py create mode 100755 diana/downstreamdeca/l2ptrainer.py create mode 100644 diana/downstreamdeca/metric/accuracy.py create mode 100644 diana/downstreamdeca/metric/bleu.py create mode 100644 diana/downstreamdeca/metric/bleu_local/bleu.py create mode 100644 diana/downstreamdeca/metric/rouge_local/rouge_metric.py create mode 100644 diana/downstreamdeca/metric/rouge_local/rouge_score.py create mode 100644 diana/downstreamdeca/metric/rouge_local/rouge_tokenizers.py create mode 100644 diana/downstreamdeca/metric/rouge_local/scoring.py create mode 100644 diana/downstreamdeca/metric/squad_v1_local/evaluate.py create mode 100644 diana/downstreamdeca/metric/squad_v1_local/squad_v1_local.py create mode 100644 diana/downstreamdeca/metric/squad_v2_local/evaluate.py create mode 100644 diana/downstreamdeca/metric/squad_v2_local/squad_v2_local.py create mode 100755 diana/downstreamdeca/simple_processors.py create mode 100644 diana/downstreamdeca/toknize.py create mode 100644 diana/metrics.py create mode 100755 diana/models/metadeca.py create mode 100755 diana/models/metadecanometa.py create mode 100755 diana/models/metadecanotask.py create mode 100755 diana/models/metat5.py create mode 100755 diana/models/metat5nometa.py create mode 100755 diana/models/metat5notask.py create mode 100644 diana/models/utils.py create mode 100644 diana/r.txt create mode 100644 diana/readme.md diff --git a/diana/data_process/QAInput.py b/diana/data_process/QAInput.py new file mode 100755 index 0000000..1ce8d40 --- /dev/null +++ b/diana/data_process/QAInput.py @@ -0,0 +1,69 @@ +#ProQA +#[paper]https://arxiv.org/abs/2205.04040 +class StructuralQAInput: + @classmethod + def qg_input_abstrativeqa(cls, context, question, options=None): + question = question + source_text = f'[TASK] [ABSTRACTIVE] [QUESTION] {question}. [CONTEXT] {context} the answer is: ' + return source_text + + @classmethod + def qg_input_boolqa(cls, context, question, options=None): + source_text = f'[TASK] [BOOL] [QUESTION] {question}. [CONTEXT] {context} the answer is: ' + return source_text + + @classmethod + def qg_input_extractive_qa(cls, context, question, options=None): + source_text = f'[TASK] [EXTRACTIVE] [QUESTION] {question.lower().capitalize()}. 
[CONTEXT] {context} the answer is: ' + return source_text + + @classmethod + def qg_input_multirc(cls, context, question, options=None): + source_text = f'[TASK] [MultiChoice] [QUESTION] {question} [OPTIONS] {options} [CONTEXT] {context} the answer is: ' + return source_text + + + +class LFQAInput: + @classmethod + def qg_input_abstrativeqa(cls, context, question, options=None): + question = question + source_text = f'{question} \\n {context}' + return source_text + + @classmethod + def qg_input_boolqa(cls, context, question, options=None): + source_text = f'{question} \\n {context}' + return source_text + + @classmethod + def qg_input_extractive_qa(cls, context, question, options=None): + source_text = f'{question.lower().capitalize()} \\n {context}' + return source_text + + @classmethod + def qg_input_multirc(cls, context, question, options=None): + source_text = f'{question} \\n {options} \\n {context}' + return source_text + +class SimpleQAInput: + @classmethod + def qg_input_abstrativeqa(cls, context, question, options=None): + question = question + source_text = f'{question} \\n {context}' + return source_text + + @classmethod + def qg_input_boolqa(cls, context, question, options=None): + source_text = f'{question} \\n {context}' + return source_text + + @classmethod + def qg_input_extractive_qa(cls, context, question, options=None): + source_text = f'{question.lower().capitalize()} \\n {context}' + return source_text + + @classmethod + def qg_input_multirc(cls, context, question, options=None): + source_text = f'{question} \\n {options} \\n {context}' + return source_text \ No newline at end of file diff --git a/diana/data_process/QGInput.py b/diana/data_process/QGInput.py new file mode 100755 index 0000000..d35f855 --- /dev/null +++ b/diana/data_process/QGInput.py @@ -0,0 +1,88 @@ +class QGInput: + @classmethod + def qg_input_abstractiveqa(cls,context,question,answers,options=None): + outputs = [] + for ans in answers: + source_text = f'The context: {context}, the answer is: {ans}. This is a summary task, please generate question: ' + target_text = f'{question}' + outputs.append({'source_text':source_text,'target_text':target_text}) + return outputs + @classmethod + def qg_input_abstractiveqa_qapairs(cls,context,question,answers,options=None): + outputs = [] + for ans in answers: + source_text = f'Generate question and the answer: The context: {context}.' + target_text = f'[Question]: {question} [Answer] {ans}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + @classmethod + def qg_input_boolqa(cls,context,question,answers,options=None): + outputs = [] + for ans in answers: + source_text = f'The context: {context}. The answer is: {ans} . This is a bool task, generate question: ' + target_text = f'{question}' + outputs.append({'source_text':source_text,'target_text':target_text}) + return outputs + @classmethod + def qg_input_extractive_qa(cls, context, question, answers,options=None): + outputs = [] + for ans in answers: + context = context.replace(ans,f' {ans} ') + source_text = f'The context: {context}. The answer is: {ans}. This is a extractive task, generate question: ' + target_text = f'{question}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + + @classmethod + def qg_input_extractive_qapairs(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + source_text = f'Generate question and the answer: [Context] {context}.' 
+ target_text = f'[question] {question} [answer] {ans}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + @classmethod + def qg_input_multirc(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + source_text = f'The context: {context}. The options are: {options}. The answer is: {ans} . This is a machine reading comprehension task, generate question: ' + target_text = f'{question}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + @classmethod + def qg_intput_multirc_negoption(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + for option in options: + if option!=ans: + source_text = f'Context: {context}. Answer is: {ans} . Generate false option: ' + target_text = f'{option}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + @classmethod + def qg_intput_multirc_qapairs(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + source_text = f'Generate question and answer: [Context] {context}' + target_text = f'[Question] {question} [Answer] {ans}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + @classmethod + def qg_intput_multirc_negoption_withq(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + for option in options: + if option != ans: + source_text = f'Generate false option: [Context] {context}. [Question] {question}. [Answer] {ans} .' + target_text = f'{option}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs + + @classmethod + def qg_intput_multirc_qfirst(cls, context, question, answers, options=None): + outputs = [] + for ans in answers: + source_text = f'Generate question: Context: {context}. Answer: {ans} . 
' + target_text = f'{question}' + outputs.append({'source_text': source_text, 'target_text': target_text}) + return outputs \ No newline at end of file diff --git a/diana/data_process/tasks.py b/diana/data_process/tasks.py new file mode 100755 index 0000000..297f86c --- /dev/null +++ b/diana/data_process/tasks.py @@ -0,0 +1,156 @@ +import logging +import t5 +import os +import json +import functools +import tensorflow as tf +import tensorflow_datasets as tfds + +DATASETS = [ + "ai2_science_middle", + "ai2_science_elementary", + "arc_hard", + "arc_easy", + "mctest", + "mctest_corrected_the_separator", + "natural_questions", + "quoref", + "squad1_1", + "squad2", + "boolq", + "multirc", + "newsqa", + "race_string", + "ropes", + "drop", + "narrativeqa", + "openbookqa", + "qasc", + "boolq_np", + "contrast_sets_boolq", + "contrast_sets_drop", + "contrast_sets_quoref", + "contrast_sets_ropes", + "commonsenseqa", + "qasc_with_ir", + "openbookqa_with_ir", + "arc_easy_with_ir", + "arc_hard_with_ir", + "ambigqa", + "natural_questions_direct_ans", + "natural_questions_with_dpr_para", + "winogrande_xs", + "winogrande_s", + "winogrande_m", + "winogrande_l", + "winogrande_xl", + "social_iqa", + "physical_iqa", +] + +DATA_DIR = f"gs://unifiedqa/data/" + +def dataset_preprocessor(ds): + def normalize_text(text): + """Lowercase and remove quotes from a TensorFlow string.""" + text = tf.strings.lower(text) + text = tf.strings.regex_replace(text, "'(.*)'", r"\1") + return text + + def to_inputs_and_targets(ex): + return { + "inputs": normalize_text(ex["inputs"]), + "targets": normalize_text(ex["targets"]) + } + + return ds.map(to_inputs_and_targets, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + +def get_path(data_dir1, split): + tsv_path = { + "train": os.path.join(data_dir1, "train.tsv"), + "dev": os.path.join(data_dir1, "dev.tsv"), + "test": os.path.join(data_dir1, "test.tsv") + } + return tsv_path[split] + + +def dataset_fn(split, shuffle_files=False, dataset=""): + # We only have one file for each split. + del shuffle_files + + # Load lines from the text file as examples. + ds = tf.data.TextLineDataset(get_path(DATA_DIR + dataset, split)) + # Split each "\t" example into (question, answer) tuple. + print(" >>>> about to read csv . . . ") + ds = ds.map( + functools.partial(tf.io.decode_csv, record_defaults=["", ""], + field_delim="\t", use_quote_delim=False), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + # print(" >>>> after reading csv . . . ") + # Map each tuple to a {"question": ... "answer": ...} dict. + ds = ds.map(lambda *ex: dict(zip(["inputs", "targets"], ex))) + # print(" >>>> after mapping . . . ") + return ds + + +for dataset in DATASETS: + print(f" >>>> reading dataset: {dataset}") + t5.data.set_tfds_data_dir_override(DATA_DIR + dataset) + t5.data.TaskRegistry.add( + f"{dataset}_task", + # Supply a function which returns a tf.data.Dataset. + dataset_fn=functools.partial(dataset_fn, dataset=dataset), + splits=["train", "dev", "test"], + # Supply a function which preprocesses text from the tf.data.Dataset. + text_preprocessor=[dataset_preprocessor], + # Use the same vocabulary that we used for pre-training. + sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH, + # Lowercase targets before computing metrics. 
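+        # (lower_text is applied to both the decoded predictions and the reference targets, so the accuracy metric compares them case-insensitively.)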
+ postprocess_fn=t5.data.postprocessors.lower_text, + metric_fns=[t5.evaluation.metrics.accuracy], + ) + print(f" >>>> adding one mixture per dataset: `{dataset}_mixture`") + t5.data.MixtureRegistry.add( + f"{dataset}_mixture", [f"{dataset}_task"], default_rate=1.0 + ) + +# dataset-pair mixtures +for dataset1 in DATASETS: + for dataset2 in DATASETS: + if dataset1 == dataset2: + continue + print(f" >>>> adding one mixture for dataset-pair: `{dataset1}_{dataset2}_mixture`") + t5.data.MixtureRegistry.add( + f"{dataset1}_{dataset2}_mixture", + [f"{dataset1}_task", f"{dataset2}_task"], + default_rate=1.0 + ) + +# union model: used for training UnifiedQA +union_datasets = [ + "narrativeqa", + "ai2_science_middle", "ai2_science_elementary", + "arc_hard", "arc_easy", + "mctest_corrected_the_separator", + "squad1_1", "squad2", + "boolq", + "race_string", + "openbookqa", +] +print(f" >>>> adding one mixture for `union_mixture`") +t5.data.MixtureRegistry.add( + f"union_mixture", + [f"{d}_task" for d in union_datasets], + default_rate=1.0 +) + +# leave-one-out +for dd in union_datasets: + filtered_datasets = [f"{d}_task" for d in union_datasets if d != dd] + assert len(filtered_datasets) < len(union_datasets) + t5.data.MixtureRegistry.add( + f"union_minus_{dd}_mixture", + filtered_datasets, + default_rate=1.0 + ) diff --git a/diana/downstream/create_l2p.py b/diana/downstream/create_l2p.py new file mode 100644 index 0000000..af08f46 --- /dev/null +++ b/diana/downstream/create_l2p.py @@ -0,0 +1,782 @@ +#!/usr/bin/env python +# coding=utf-8 + +import logging +import os +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +# from data import QAData +sys.path.append("../") +from models.nopt5 import T5ForConditionalGeneration as PromptT5 +# from models.promptT5 import PromptT5 +from downstream.dataset_processors import * +from downstream.trainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +# check_min_version("4.12.0.dev0") + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. 
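+# This list appears to be inherited from the Hugging Face seq2seq example scripts; the QA fine-tuning below loads a T5 tokenizer, so the list is not referenced again in this file.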
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], #,'ropes' + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} +task2id = {} +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k +task2format={} +for k,v in dataset2format.items(): + if v=="extractive": + task2format[k]=0 + elif v=="abstractive": + task2format[k]=1 + else: + task2format[k]=2 + # task2id[v] = len(task2id.keys()) +# task2id = {v:k for k,v in enumerate(task_list)} +task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5, 'boolq': 6, 'bool': 7, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17} + +# task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive':1 , 'race': 2, 'multichoice': 2, 'boolq': 3, 'bool': 3, 'newsqa':8,'quoref':9,'ropes':10,'drop':11,'nqopen':12,'boolq_np':13,'openbookqa':14,'mctest':15,'social_iqa':16,'dream':17} +# print(task2id) + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. 
This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." + "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=40, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." 
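+            # Fall back to the training target length when no validation target length is given.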
+ if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + def preprocess_valid_function(examples): + preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "question","context","answer") + model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=128, padding=False, truncation=True) + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + + + + gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)] + format_id = task2format[dataset_name] + format_prompt_id_start = 500 + format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40, + format_prompt_id_start + (format_id + 1) * 40)] + task_id = task2id[dataset_name] + task_prompt_id_start = 0 + task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40, + task_prompt_id_start + (task_id + 1) * 40)] + + + domain_prompt_id_start = 800 + domain_prompt_number = 20 + domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start, + domain_prompt_id_start + 20)]*5 + input_ids = copy.deepcopy( + [gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]) + model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']] + model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in + model_inputs['attention_mask']] + model_inputs["labels"] = labels["input_ids"] + + + return model_inputs + + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + def preprocess_function(examples): + preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "question","context","answer") + model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer(targets, max_length=128, padding=False, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + gen_prompt_ids = [-(i+1) for i in range(1520,1520+20)] + format_id = task2format[dataset_name] + format_prompt_id_start = 500 + format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 40, + format_prompt_id_start + (format_id + 1) * 40)] + task_id = task2id[dataset_name] + task_prompt_id_start = 0 + task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 40, + task_prompt_id_start + (task_id + 1) * 40)] + + + domain_prompt_id_start = 800 + domain_prompt_number = 20 + domain_prompt_ids = [- (i + 1) for i in 
range(domain_prompt_id_start, + domain_prompt_id_start + 20)]*5 + input_ids = copy.deepcopy( + [gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]) + model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']] + model_inputs['attention_mask'] = [[1] * 200 + attention_mask for attention_mask in + model_inputs['attention_mask']] + return model_inputs + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + if 'same' in model_args.model_name_or_path: + task2id = {'squad': 0, 'extractive': 0, 'narrativeqa': 1, 'abstractive': 1, 'race': 2, 'multichoice': 2, + 'boolq': 3, 'bool': 3, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12, + 'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17} + else: + task2id = {'squad': 0, 'extractive': 1, 'narrativeqa': 2, 'abstractive': 3, 'race': 4, 'multichoice': 5, + 'boolq': 6, 'bool': 7, 'newsqa': 8, 'quoref': 9, 'ropes': 10, 'drop': 11, 'nqopen': 12, + 'boolq_np': 13, 'openbookqa': 14, 'mctest': 15, 'social_iqa': 16, 'dream': 17} + + dataset_name_to_metric = { + 'squad': 'squad', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'accuracy', + 'narrativeqa': 'rouge', + 'race': 'accuracy', + 'quoref': 'squad', + 'ropes': 'squad', + 'drop': 'squad', + 'nqopen': 'squad', + # 'multirc': 'accuracy', + 'boolq_np': 'accuracy', + 'openbookqa': 'accuracy', + 'mctest': 'accuracy', + 'social_iqa': 'accuracy', + 'dream': 'accuracy', + } + + + + + + + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: 
+ logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]', + # '[OPTIONS]']) + tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]'] + special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', + '[OPTIONS]']} + + tokenizer.add_tokens(tokens_to_add) + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + added_tokens = tokenizer.get_added_vocab() + logger.info('Added tokens: {}'.format(added_tokens)) + + model = PromptT5.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + # task_num = data_args.max_task_num, + # prompt_num = data_args.prompt_number, + # format_num = data_args.qa_task_type_num, + # add_task_prompt = False + ) + + model.resize_token_embeddings(len(tokenizer)) + + + #reload format specific task-prompt for newly involved task +#format_prompts###task_promptsf + + data_args.reload_from_trained_prompt = False#@ + data_args.load_from_format_task_id = False#@ + ### before pretrain come !!!!!! 
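+    # Two optional ways to (re)initialise rows of `encoder.prompt_embeddings.weight` before training:
+    #   (1) load_from_format_task_id copies the prompt rows of the task's format into the slot of a
+    #       new (non-seed) task, so it starts from its format prompt instead of random values;
+    #   (2) reload_from_trained_prompt restores both the task rows and the format rows from a saved
+    #       `prompt_embedding_info` file (see save_prompt_embedding above).
+    # Both flags are hard-coded to False a few lines above, so neither branch runs here.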
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
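+    # (max_target_length is switched to val_max_target_length further down, just before the validation split is tokenized.)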
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + for ds_name in task2id.keys(): + if (not ds_name in ["extractive","abstractive","multichoice","bool","boolq","boolq_np","ropes"]): + # data_args.dataset_name = cur_dataset + cur_dataset = ds_name + data_args.dataset_name = cur_dataset + if True: + # Downloading and loading a dataset from the hub. + if not data_args.dataset_name in ['newsqa', 'nqopen', 'mctest', 'social_iqa']: + if data_args.dataset_name == "race": + data_args.dataset_config_name = "all" + elif data_args.dataset_name == "openbookqa": + data_args.dataset_config_name = "main" + elif data_args.dataset_name == "dream": + data_args.dataset_config_name = "plain_text" + else: + data_args.dataset_config_name = "plain_text" + raw_datasets = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir + ) + if data_args.dataset_name in ['ropes']: + # add answer_start (not used for squad evaluation but required) + def add_answer_start(example): + example['answers'].update({"answer_start": [0]}) + return example + raw_datasets = raw_datasets.map(add_answer_start) + elif data_args.dataset_name in ['drop']: + # add answer_start (not used for squad evaluation but required) + # add answers (for squad evaluation) + def add_answers(example): + answers = [] + answer_start = [] + for _a in example['answers_spans']['spans']: + answers.append(_a) + answer_start.append(-1) + example['answers'] = {"text": answers, "answer_start": answer_start} + return example + raw_datasets = raw_datasets.map(add_answers) + column_names = raw_datasets["validation"].column_names + + + else: + data_files = {} + basic_file = "../data_process/data/"+data_args.dataset_name+"/" + data_files["train"] = basic_file+"train.json" + # extension = data_args.train_file.split(".")[-1] + + data_files["validation"] = basic_file+"dev.json" + # extension = data_args.validation_file.split(".")[-1] + + if data_args.dataset_name in ['newsqa', 'nqopen', 'multirc', 'boolq_np', 'mctest', 'social_iqa']: + raw_datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) + else: + print(f"Unknown 
dataset {data_args.dataset_name}") + raise NotImplementedError + column_names = raw_datasets["validation"].column_names + metric = load_metric(dataset_name_to_metric[data_args.dataset_name]) + + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + + if True: + all_num = list(range(0, len(train_dataset))) + random.shuffle(all_num) + selected_indices = all_num[:100] + replay_dataset = train_dataset.select(selected_indices) + + # train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + with training_args.main_process_first(desc="train dataset map pre-processing"): + replay_dataset = replay_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on replay dataset", + ) + replay_dataloaders[ds_name] = replay_dataset + # replay_dataset = load_from_disk("./processed/{}-replay.hf".format(ds_name)) + # print(replay_dataset) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on train dataset", + ) + # print(train_dataset) + train_dataloaders[ds_name] = train_dataset + train_dataset.save_to_disk("./ours/{}-train.hf".format(ds_name)) + # train_dataset = load_from_disk("./processed/{}-train.hf".format(ds_name)) + max_target_length = data_args.val_max_target_length + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_examples = raw_datasets["validation"] + def add_id(example,index): + example.update({'id':index}) + return example + if 'id' not in eval_examples.features.keys(): + eval_examples = eval_examples.map(add_id,with_indices=True) + if data_args.max_eval_samples is not None: + eval_examples = eval_examples.select(range(data_args.max_eval_samples)) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_examples.map( + preprocess_validation_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on validation dataset", + ) + eval_dataloaders[ds_name] = (eval_dataset,eval_examples) + eval_dataset.save_to_disk("./ours/{}-eval.hf".format(ds_name)) + + eval_examples.save_to_disk("./ours/{}-evalex.hf".format(ds_name)) + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstream/dataset_processors.py b/diana/downstream/dataset_processors.py new file mode 100755 index 0000000..2b056b5 --- /dev/null +++ b/diana/downstream/dataset_processors.py @@ -0,0 +1,192 @@ +import sys +sys.path.append('../data_process') +from typing import List, Optional, Tuple +from QAInput import StructuralQAInput as QAInput +#from QAInput import QAInputNoPrompt as QAInput + +def preprocess_sqaud_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = 
examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_sqaud_abstractive_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_boolq_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + question_column, context_column, answer_column = 'question', 'passage', 'answer' + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_boolq_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_narrativeqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [exp['summary']['text'] for exp in examples['document']] + questions = [exp['text'] for exp in examples['question']] + answers = [ans[0]['text'] for ans in examples['answers']] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_narrativeqa_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_drop_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['passage'] + questions = examples['question'] + answers = examples['answers'] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_race_batch( + examples, + question_column: str, + context_column: str, + 
answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 } + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_newsqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + + +def preprocess_ropes_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + backgrounds = examples["background"] + situations = examples["situation"] + answers = examples[answer_column] + + inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_openbookqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples['question_stem'] + all_options = examples['choices'] + answers = examples['answerKey'] + options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options] + inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_social_iqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. 
{options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2,} + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_dream_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [" ".join(dialogue) for dialogue in examples['dialogue']] + questions = examples['question'] + all_options = examples['choice'] + answers = examples['answer'] + answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + targets = answers + return inputs, targets \ No newline at end of file diff --git a/diana/downstream/diana-uni.sh b/diana/downstream/diana-uni.sh new file mode 100644 index 0000000..a3ab9f8 --- /dev/null +++ b/diana/downstream/diana-uni.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -u diana.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + --load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstream/diana.py b/diana/downstream/diana.py new file mode 100644 index 0000000..f160ca6 --- /dev/null +++ b/diana/downstream/diana.py @@ -0,0 +1,827 @@ +#!/usr/bin/env python +# coding=utf-8 +import logging +import os +os.environ['NCCL_ASYNC_ERROR_HANDLING']='1' +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metat5 import T5ForConditionalGeneration as PromptT5 +from downstream.dataset_processors import * +from downstream.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + 
MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} + +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + +task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10} + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." 
+ }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + if True: + pass + # logger.info(f'Loading task {data_args.dataset_name} prompt') + # format_id = format2id[dataset2format[data_args.dataset_name]] + # task_id = task2id[data_args.dataset_name] + # format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, ( + # format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1))) + # task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number + # logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)) + # task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)] + # input_ids = copy.deepcopy( + # [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + #input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + # model_inputs['input_ids'] = input_ids + # model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in + # model_inputs['attention_mask']] + + # input_ids = copy.deepcopy([input]) + return model_inputs + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + with torch.no_grad(): + device = 'cuda:0' + search_upbound = len(trainset)//4 + query_idxs = [None]*30 + keys = model.encoder.domain_keys + for idx,item in enumerate(trainset.select(range(search_upbound))): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + result = torch.matmul(query,keys.t()) + result = torch.topk(result,5).indices[0].cpu().numpy().tolist() + key_sel = None + for key_idx in result: + if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3: + key_sel = key_idx + break + if key_sel is not None: + if query_idxs[key_sel] is None: + query_idxs[key_sel] = [idx] + else: + query_idxs[key_sel].append(idx) + total_idxs = [] + for item in query_idxs: + try: + total_idxs.extend(item[:3]) + except: + total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3)) + total_idxs = list(set(total_idxs)) + total_idxs = random.sample(total_idxs,50) + sub_set = trainset.select(total_idxs) + + features = [] + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + + logging.basicConfig( + 
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]', + # '[OPTIONS]']) + tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]'] + special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', + '[OPTIONS]']} + + tokenizer.add_tokens(tokens_to_add) + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + added_tokens = tokenizer.get_added_vocab() + logger.info('Added tokens: {}'.format(added_tokens)) + + model = PromptT5.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + # task_num = data_args.max_task_num, + # prompt_num = data_args.prompt_number, + # format_num = data_args.qa_task_type_num, + # add_task_prompt = False + ) + + model.resize_token_embeddings(len(tokenizer)) + + + #reload format specific task-prompt for newly involved task +#format_prompts###task_promptsf + + data_args.reload_from_trained_prompt = False#@ + data_args.load_from_format_task_id = False#@ + ### before pretrain come !!!!!! 
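+    # Note on the prompt-embedding layout assumed by the two branches below (a reading of the
+    # indexing in this file, not an authoritative spec): rows [0, len(format2id) * prompt_number)
+    # of encoder.prompt_embeddings.weight hold the format-level prompts, one block of
+    # `prompt_number` rows per format, and the rows from `prompt_number * len(format2dataset)`
+    # onward hold the per-task prompts, again `prompt_number` rows per task in `task2id` order.
+    # `load_from_format_task_id` is intended to initialize a new task's prompt block from a
+    # format-level block (see the log message below); `reload_from_trained_prompt` restores both
+    # blocks from a saved prompt checkpoint. Both flags are forced to False above, so neither
+    # branch runs in this script.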
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']: + eval_dataloaders[ds_name] = (load_from_disk("./oursfinallong/{}-eval.hf".format(ds_name)),load_from_disk("./oursfinallong/{}-evalex.hf".format(ds_name))) + + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa'] + # task_sequence = ["woz.en","srl","sst","wikisql","squad"] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + p=200 + if p>len(load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset))): + p = len(load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset))) + + trainds = load_from_disk("./oursfinallong/{}-train.hf".format(cur_dataset)) + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["drop","quoref","dream"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric(dataset_name_to_metric[cur_dataset]) + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,trainds]) + else: + fused = trainds + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() 
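+        # After training on the current task (accumulated replay buffer + new data), rank 0 builds
+        # the episodic memory for this task: save_load_diverse_sample() scores up to the first
+        # quarter of the training set against the encoder's domain keys (top-5 keys per example,
+        # at most 3 examples kept per key), falls back to 3 random examples from the unsearched
+        # tail of the training set when a key gets no matches, and finally keeps 50 exemplars plus
+        # their query features. The exemplars are written to all_replay@<task>.hf and the features
+        # to all_features.npy; other ranks wait at the barrier and reload both. A second training
+        # pass then runs on the replay buffer, and the collected (task id, feature) pairs are
+        # registered with the encoder via add_negs before evaluation on all previously seen tasks.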
+ + if training_args.local_rank<=0: + save_set,features = save_load_diverse_sample(model,trainds) + + if all_replay is None: + all_replay = save_set + else: + all_replay = datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + + + model.encoder.reset_train_count() + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=all_replay, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + model.encoder.add_negs(all_ids,all_features) + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric(dataset_name_to_metric[pre_dataset]) + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstream/diana.sh b/diana/downstream/diana.sh new file mode 100644 index 0000000..6eb88bc --- /dev/null +++ b/diana/downstream/diana.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 diana.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate 
\ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + --load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstream/dianawometa.py b/diana/downstream/dianawometa.py new file mode 100644 index 0000000..329f293 --- /dev/null +++ b/diana/downstream/dianawometa.py @@ -0,0 +1,798 @@ +#!/usr/bin/env python +# coding=utf-8 +import logging +import os +os.environ['NCCL_ASYNC_ERROR_HANDLING']='1' +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metat5nometa import T5ForConditionalGeneration as PromptT5 +from downstream.dataset_processors import * +from downstream.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} + +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + +task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10} + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + if True: + pass + # logger.info(f'Loading task {data_args.dataset_name} prompt') + # format_id = format2id[dataset2format[data_args.dataset_name]] + # task_id = task2id[data_args.dataset_name] + # format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, ( + # format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1))) + # task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number + # logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)) + # task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)] + # input_ids = copy.deepcopy( + # [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + #input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + # model_inputs['input_ids'] = input_ids + # model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in + # model_inputs['attention_mask']] + + # input_ids = copy.deepcopy([input]) + return model_inputs + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + features = [] + device='cuda:0' + with torch.no_grad(): + sub_set = trainset.select(range(50)) + + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, 
e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]', + # '[OPTIONS]']) + tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]'] + special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', + '[OPTIONS]']} + + tokenizer.add_tokens(tokens_to_add) + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + added_tokens = tokenizer.get_added_vocab() + logger.info('Added tokens: {}'.format(added_tokens)) + + model = PromptT5.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + # task_num = data_args.max_task_num, + # prompt_num = data_args.prompt_number, + # format_num = data_args.qa_task_type_num, + # add_task_prompt = False + ) + + model.resize_token_embeddings(len(tokenizer)) + + + #reload format specific task-prompt for newly involved task +#format_prompts###task_promptsf + + data_args.reload_from_trained_prompt = False#@ + data_args.load_from_format_task_id = False#@ + ### before pretrain come !!!!!! 
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']: + eval_dataloaders[ds_name] = (load_from_disk("./oursnometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursnometa/{}-evalex.hf".format(ds_name))) + + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa'] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + + trainds = load_from_disk("./oursnometa/{}-train.hf".format(cur_dataset)).select(range(1000)) + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["drop","quoref","dream"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric(dataset_name_to_metric[cur_dataset]) + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,trainds]) + else: + fused = trainds + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + if training_args.local_rank<=0: + try: + save_set,features = save_load_diverse_sample(model,trainds) + + if all_replay is None: + all_replay = save_set + else: + all_replay = 
datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + except: + pass + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + + + model.encoder.reset_train_count() + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=all_replay, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + model.encoder.add_negs(all_ids,all_features) + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric(dataset_name_to_metric[pre_dataset]) + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstream/dianawometa.sh b/diana/downstream/dianawometa.sh new file mode 100644 index 0000000..50a03b5 --- /dev/null +++ b/diana/downstream/dianawometa.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + 
--label_smoothing_factor 0.1 \ + --do_lowercase True \ + --load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstream/dianawotask.py b/diana/downstream/dianawotask.py new file mode 100644 index 0000000..e5fdaa0 --- /dev/null +++ b/diana/downstream/dianawotask.py @@ -0,0 +1,827 @@ +#!/usr/bin/env python +# coding=utf-8 +import logging +import os +os.environ['NCCL_ASYNC_ERROR_HANDLING']='1' +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metat5notask import T5ForConditionalGeneration as PromptT5 +from downstream.dataset_processors import * +from downstream.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} + +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + +task2id = {'squad': 0, 'narrativeqa': 1, 'race': 2, 'newsqa':3,'quoref':4,'drop':5,'nqopen':6,'openbookqa':7,'mctest':8,'social_iqa':9,'dream':10} + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
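+            # The example_id recorded here is what the trainer's post-processing later uses to map each
+            # generated prediction back to its source example when assembling metric inputs.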
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + if True: + pass + # logger.info(f'Loading task {data_args.dataset_name} prompt') + # format_id = format2id[dataset2format[data_args.dataset_name]] + # task_id = task2id[data_args.dataset_name] + # format_prompt_ids = [- (i + 1) for i in range(format_id * data_args.prompt_number, ( + # format_id + 1) * data_args.prompt_number)] # list(range(-(format_id * data_args.prompt_number+1), -((format_id + 1) * data_args.prompt_number+1))) + # task_prompt_id_start = len(format2id.keys()) * data_args.prompt_number + # logger.info('Prompt ids {}: {}'.format(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)) + # task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * data_args.prompt_number, + # task_prompt_id_start + (task_id + 1) * data_args.prompt_number)] + # input_ids = copy.deepcopy( + # [format_prompt_ids + task_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + #input_ids = copy.deepcopy([format_prompt_ids + input_ids for input_ids in model_inputs['input_ids']]) + # model_inputs['input_ids'] = input_ids + # model_inputs['attention_mask'] = [[1] * data_args.prompt_number * 2 + attention_mask for attention_mask in + # model_inputs['attention_mask']] + + # input_ids = copy.deepcopy([input]) + return model_inputs + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + with torch.no_grad(): + device = 'cuda:0' + search_upbound = len(trainset)//4 + query_idxs = [None]*30 + keys = model.encoder.domain_keys + for idx,item in enumerate(trainset.select(range(search_upbound))): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + result = torch.matmul(query,keys.t()) + result = torch.topk(result,5).indices[0].cpu().numpy().tolist() + key_sel = None + for key_idx in result: + if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3: + key_sel = key_idx + break + if key_sel is not None: + if query_idxs[key_sel] is None: + query_idxs[key_sel] = [idx] + else: + query_idxs[key_sel].append(idx) + total_idxs = [] + for item in query_idxs: + try: + total_idxs.extend(item[:3]) + except: + total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3)) + total_idxs = list(set(total_idxs)) + total_idxs = random.sample(total_idxs,50) + sub_set = trainset.select(total_idxs) + + features = [] + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + + logging.basicConfig( + 
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]', + # '[OPTIONS]']) + tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]'] + special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', + '[OPTIONS]']} + + tokenizer.add_tokens(tokens_to_add) + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + added_tokens = tokenizer.get_added_vocab() + logger.info('Added tokens: {}'.format(added_tokens)) + + model = PromptT5.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + # task_num = data_args.max_task_num, + # prompt_num = data_args.prompt_number, + # format_num = data_args.qa_task_type_num, + # add_task_prompt = False + ) + + model.resize_token_embeddings(len(tokenizer)) + + + #reload format specific task-prompt for newly involved task +#format_prompts###task_promptsf + + data_args.reload_from_trained_prompt = False#@ + data_args.load_from_format_task_id = False#@ + ### before pretrain come !!!!!! 
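+    # Layout of encoder.prompt_embeddings.weight as implied by the indexing below (an inference from this
+    # code, not an external spec): the first len(format2id) * prompt_number rows hold one shared prompt
+    # block per answer format, and task-specific blocks follow, with task t occupying rows
+    # [task_start_id + t * prompt_number, task_start_id + (t + 1) * prompt_number), where
+    # task_start_id = prompt_number * len(format2dataset). Both reload flags are forced to False just
+    # above, so neither branch below is taken in this run.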
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ['squad','narrativeqa','race','newsqa','quoref','drop','nqopen','openbookqa','mctest','social_iqa','dream']: + eval_dataloaders[ds_name] = (load_from_disk("./oursnotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursnotask/{}-evalex.hf".format(ds_name))) + + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ['squad','newsqa','narrativeqa','nqopen','race','openbookqa','mctest','social_iqa'] + # task_sequence = ["woz.en","srl","sst","wikisql","squad"] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + p=200 + if p>len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))): + p = len(load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset))) + + trainds = load_from_disk("./oursnotask/{}-train.hf".format(cur_dataset)) + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["drop","quoref","dream"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric(dataset_name_to_metric[cur_dataset]) + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,trainds]) + else: + fused = trainds + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + if 
training_args.local_rank<=0: + save_set,features = save_load_diverse_sample(model,trainds) + + if all_replay is None: + all_replay = save_set + else: + all_replay = datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + + + model.encoder.reset_train_count() + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=all_replay, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + model.encoder.add_negs(all_ids,all_features) + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric(dataset_name_to_metric[pre_dataset]) + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstream/dianawotask.sh b/diana/downstream/dianawotask.sh new file mode 100644 index 0000000..a0c6fb9 --- /dev/null +++ b/diana/downstream/dianawotask.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + 
--predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + --load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} ./t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstream/l2ptrainer.py b/diana/downstream/l2ptrainer.py new file mode 100755 index 0000000..6033d7e --- /dev/null +++ b/diana/downstream/l2ptrainer.py @@ -0,0 +1,556 @@ +from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import nltk +import datasets +import re +import os +import numpy as np +import torch +import random +from pathlib import Path +import nltk +# nltk.download('punkt') +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + HPSearchBackend, + HubStrategy, + IntervalStrategy, + PredictionOutput, + ShardedDDPOption, + TrainerMemoryTracker, + TrainOutput, + default_compute_objective, + default_hp_space, + denumpify_detensorize, + get_last_checkpoint, + number_of_arguments, + set_seed, + speed_metrics, +) +import warnings +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + DistributedTensorGatherer, + IterableDatasetShard, + LabelSmoother, + LengthGroupedSampler, + SequentialDistributedSampler, + ShardSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.file_utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + get_full_repo_name, + is_apex_available, + is_datasets_available, + is_in_notebook, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_tpu_available, +) +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" +from transformers.trainer_utils import PredictionOutput,EvalLoopOutput +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + +def fix_buggy_characters(str): + return re.sub("[{}^\\\\`\u2047<]", " ", str) +def replace_punctuation(str): + return str.replace("\"", "").replace("'", "") +def score_string_similarity(str1, str2): + if str1 == str2: + return 3.0 # Better than perfect token match + str1 = fix_buggy_characters(replace_punctuation(str1)) + str2 = fix_buggy_characters(replace_punctuation(str2)) + if str1 == str2: + return 2.0 + if " " in str1 or " " in str2: + str1_split = str1.split(" ") + str2_split = str2.split(" ") + overlap = list(set(str1_split) & set(str2_split)) + return len(overlap) / max(len(str1_split), len(str2_split)) + else: + if str1 == str2: + return 1.0 + else: + return 0.0 + +class QuestionAnsweringTrainer(Seq2SeqTrainer): + def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.answer_column_name = answer_column_name 
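+        # The branches below choose the answer post-processing routine per dataset: SQuAD-style sets are
+        # scored as extracted spans, boolq-style sets map generated "true"/"false" strings to 0/1 labels,
+        # narrativeqa is scored as free-form text, and the multiple-choice sets pick the option whose
+        # string best matches the generated answer via score_string_similarity defined above.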
+ self.dataset_name = dataset_name + self.tokenizer = tokenizer + if dataset_name in ['squad', "quoref", 'ropes', 'drop', 'nqopen']: + self.post_process_function = self._post_process_squad + elif dataset_name in ['boolq', 'multirc', 'boolq_np']: + self.post_process_function = self._post_process_boolq + elif dataset_name== 'narrativeqa': + self.post_process_function = self._post_process_narrative_qa + elif dataset_name in ['race', 'mctest', 'social_iqa']: + self.post_process_function = self._post_process_race + elif dataset_name in ["squad_v2", "newsqa"]: + self.post_process_function = self._post_process_squad_v2 + elif dataset_name in ['openbookqa']: + self.post_process_function = self._post_process_openbookqa + elif dataset_name in ['dream']: + self.post_process_function = self._post_process_dream + else: + print(f"{dataset_name} has not QuestionAnsweringTrainer.post_process_function!") + raise NotImplementedError + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + inputs["train"]=True + else: + labels = None + outputs,sim_loss = model(**inputs) + + # print(model.encoder.) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + loss = self.label_smoother(outputs, labels) + else: + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + loss += sim_loss*0.1 + # print(sim_loss) + return (loss, outputs) if return_outputs else loss + + + def _post_process_squad( + self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval" + ,version_2_with_negative=False): + # Decode the predicted tokens. + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} + predictions = {} + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # This is the index of the feature associated to the current example. + try: + feature_index = feature_per_example[example_index] + except: + print(feature_per_example) + print(example_index) + print(len(feature_per_example)) + predictions[example["id"]] = decoded_preds[feature_index] + + # Format the result to the format the metric expects. + if version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + def _post_process_squad_v2( + self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"): + # Decode the predicted tokens. 
+ preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} + predictions = {} + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # This is the index of the feature associated to the current example. + feature_index = feature_per_example[example_index] + predictions[example["id"]] = decoded_preds[feature_index] + + # Format the result to the format the metric expects. + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + + references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + @classmethod + def postprocess_text(cls,preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + return preds, labels + + def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + outputs = [] + # for pred in preds: + # if 'True' == pred: + # outputs.append(1) + # elif 'False' == pred: + # outputs.append(0) + # else: + # outputs.append(0) + # references = [1 if ex['answer'] else 0 for ex in examples] + + references = [] + for pred, ref in zip(preds, examples): + ans = ref['answer'] + references.append(1 if ans else 0) + if pred.lower() in ['true', 'false']: + outputs.append(1 if pred.lower() == 'true' else 0) + else: + # generate a wrong prediction if pred is not true/false + outputs.append(0 if ans else 1) + + assert(len(references)==len(outputs)) + formatted_predictions = [] + return EvalPrediction(predictions=outputs, label_ids=references) + + def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + references = [exp['answers'][0]['text'] for exp in examples] + formatted_predictions = [] + preds, references = self.postprocess_text(preds,references) + assert(len(preds)==len(references)) + return EvalPrediction(predictions=preds, label_ids=references) + + def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + references = [exp['answers'][0]['text'] for exp in examples] + 
formatted_predictions = [] + preds, references = self.postprocess_text(preds, references) + assert (len(preds) == len(references)) + return EvalPrediction(predictions=preds, label_ids=references) + + def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples] + assert(len(references)==len(decoded_preds)) + gold_ids , pred_ids = [],[] + for prediction, reference in zip(decoded_preds, references): + # reference = json.loads(reference) + gold = int(ord(reference['answer'].strip()) - ord('A')) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + # print('options',options,type(options)) + # print('prediction',prediction) + # print('answer',gold) + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + # selected_ans = chr(ord('A') + max_idx) + # print(len(references),len(decoded_preds)) + + return EvalPrediction(predictions=pred_ids,label_ids = gold_ids) + + def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples] + assert len(references) == len(decoded_preds) + gold_ids, pred_ids = [], [] + for prediction, reference in zip(decoded_preds, references): + # reference = json.loads(reference) + gold = int(ord(reference['answer'].strip()) - ord('A')) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + return EvalPrediction(predictions=pred_ids,label_ids = gold_ids) + + def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples] + assert(len(references)==len(decoded_preds)) + gold_ids, pred_ids = [],[] + for prediction, reference in zip(decoded_preds, references): + gold = reference['options'].index(reference['answer']) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + return EvalPrediction(predictions=pred_ids, label_ids = gold_ids) + + def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "eval", + max_length: Optional[int] = None,num_beams: Optional[int] = None): + 
self._memory_tracker.start() + self._max_length = max_length if max_length is not None else self.args.generation_max_length + self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, output,tokenizer) + metrics = self.compute_metrics(eval_preds) + if self.dataset_name=='narrativeqa': + metrics = {key: value * 100 for key, value in metrics.items()} + metrics = {k: round(v, 4) for k, v in metrics.items()} + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + print(metrics) + else: + metrics = {} + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + self._memory_tracker.stop_and_update_metrics(metrics) + return metrics + + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + save_only_best = True + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + run_id = tune.get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + self.store_flos() + + output_dir = os.path.join(run_dir, checkpoint_folder) + if not save_only_best: + self.save_model(output_dir) + if self.deepspeed: + # under zero3 model file itself doesn't get saved since it's bogus! 
Unless deepspeed + # config `stage3_gather_fp16_weights_on_model_save` is True + self.deepspeed.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + if smp.dp_rank() == 0: + # Consolidate the state dict on all processed of dp_rank 0 + opt_state_dict = self.optimizer.state_dict() + # Save it and the scheduler on the main process + if self.args.should_save and not save_only_best: + torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.deepspeed and not save_only_best: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + self.do_grad_scaling=True + # if self.do_grad_scaling: + # torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + best_dir = os.path.join(run_dir,'best-checkpoint') + # if not os.path.exists(best_dir): + # os.mkdir(best_dir,exist_ok=True) + if self.args.should_save: + self.save_model(best_dir) + self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME)) + + + # Save the Trainer state + if self.args.should_save and not save_only_best: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.local_rank == -1: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. 
+ os.makedirs(output_dir, exist_ok=True) + local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank + if local_rank == -1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save and not save_only_best: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + + + + # def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict", + # max_length: Optional[int] = None,num_beams: Optional[int] = None): + # self._memory_tracker.start() + # self._max_length = max_length if max_length is not None else self.args.generation_max_length + # self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + # predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset + # predict_dataloader = self.get_eval_dataloader(predict_dataset) + # predict_examples = predict_examples + # + # # Temporarily disable metric computation, we will do it in the loop here. + # compute_metrics = self.compute_metrics + # self.compute_metrics = None + # eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + # try: + # output = eval_loop( + # predict_dataloader, + # description="Predict", + # # No point gathering the predictions if there are no metrics, otherwise we defer to + # # self.args.prediction_loss_only + # prediction_loss_only=True if compute_metrics is None else None, + # ignore_keys=ignore_keys, + # ) + # finally: + # self.compute_metrics = compute_metrics + # + # if self.post_process_function is not None and self.compute_metrics is not None: + # predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer) + # metrics = self.compute_metrics(predict_preds) + # if self.dataset_name=='narrativeqa': + # metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()} + # metrics = {k: round(v, 4) for k, v in metrics.items()} + # # Prefix all keys with metric_key_prefix + '_' + # for key in list(metrics.keys()): + # if not key.startswith(f"{metric_key_prefix}_"): + # metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + # + # self.log(metrics) + # else: + # metrics = {} + # + # self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + # self._memory_tracker.stop_and_update_metrics(metrics) + # return metrics diff --git a/diana/downstream/metric/accuracy.py b/diana/downstream/metric/accuracy.py new file mode 100644 index 0000000..b9b0e33 --- /dev/null +++ b/diana/downstream/metric/accuracy.py @@ -0,0 +1,105 @@ +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Accuracy metric.""" + +from sklearn.metrics import accuracy_score + +import datasets + + +_DESCRIPTION = """ +Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with: +Accuracy = (TP + TN) / (TP + TN + FP + FN) + Where: +TP: True positive +TN: True negative +FP: False positive +FN: False negative +""" + + +_KWARGS_DESCRIPTION = """ +Args: + predictions (`list` of `int`): Predicted labels. + references (`list` of `int`): Ground truth labels. + normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True. + sample_weight (`list` of `float`): Sample weights Defaults to None. + +Returns: + accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy. + +Examples: + + Example 1-A simple example + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0]) + >>> print(results) + {'accuracy': 0.5} + + Example 2-The same as Example 1, except with `normalize` set to `False`. + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False) + >>> print(results) + {'accuracy': 3.0} + + Example 3-The same as Example 1, except with `sample_weight` set. + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4]) + >>> print(results) + {'accuracy': 0.8778625954198473} +""" + + +_CITATION = """ +@article{scikit-learn, + title={Scikit-learn: Machine Learning in {P}ython}, + author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. + and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. + and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and + Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, + journal={Journal of Machine Learning Research}, + volume={12}, + pages={2825--2830}, + year={2011} +} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Accuracy(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { + "predictions": datasets.Value("int32"), + "references": datasets.Value("int32"), + } + ), + reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"], + ) + + def _compute(self, predictions, references, normalize=True, sample_weight=None): + return { + "accuracy": float( + accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight) + ) + } diff --git a/diana/downstream/metric/bleu.py b/diana/downstream/metric/bleu.py new file mode 100644 index 0000000..545cb1e --- /dev/null +++ b/diana/downstream/metric/bleu.py @@ -0,0 +1,126 @@ +# Copyright 2020 The HuggingFace Datasets Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BLEU metric. """ + +import datasets + +from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py + + +_CITATION = """\ +@INPROCEEDINGS{Papineni02bleu:a, + author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, + title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, + booktitle = {}, + year = {2002}, + pages = {311--318} +} +@inproceedings{lin-och-2004-orange, + title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", + author = "Lin, Chin-Yew and + Och, Franz Josef", + booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", + month = "aug 23{--}aug 27", + year = "2004", + address = "Geneva, Switzerland", + publisher = "COLING", + url = "https://www.aclweb.org/anthology/C04-1072", + pages = "501--507", +} +""" + +_DESCRIPTION = """\ +BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. +Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, +the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and +remains one of the most popular automated and inexpensive metrics. + +Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. +Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness +are not taken into account[citation needed]. + +BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1 +representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the +reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional +reference translations will increase the BLEU score. +""" + +_KWARGS_DESCRIPTION = """ +Computes BLEU score of translated segments against one or more references. +Args: + predictions: list of translations to score. + Each translation should be tokenized into a list of tokens. + references: list of lists of references for each translation. + Each reference should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. 
+Returns: + 'bleu': bleu score, + 'precisions': geometric mean of n-gram precisions, + 'brevity_penalty': brevity penalty, + 'length_ratio': ratio of lengths, + 'translation_length': translation_length, + 'reference_length': reference_length +Examples: + + >>> predictions = [ + ... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample + ... ["foo", "bar", "foobar"] # tokenized prediction of the second sample + ... ] + >>> references = [ + ... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references) + ... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference) + ... ] + >>> bleu = datasets.load_metric("bleu") + >>> results = bleu.compute(predictions=predictions, references=references) + >>> print(results["bleu"]) + 1.0 +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Bleu(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), + "references": datasets.Sequence( + datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references" + ), + } + ), + codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], + reference_urls=[ + "https://en.wikipedia.org/wiki/BLEU", + "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", + ], + ) + + def _compute(self, predictions, references, max_order=4, smooth=False): + score = compute_bleu( + reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth + ) + (bleu, precisions, bp, ratio, translation_length, reference_length) = score + return { + "bleu": bleu, + "precisions": precisions, + "brevity_penalty": bp, + "length_ratio": ratio, + "translation_length": translation_length, + "reference_length": reference_length, + } diff --git a/diana/downstream/metric/bleu_local/bleu.py b/diana/downstream/metric/bleu_local/bleu.py new file mode 100644 index 0000000..03eda55 --- /dev/null +++ b/diana/downstream/metric/bleu_local/bleu.py @@ -0,0 +1,219 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BLEU metric. 
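+
+Local variant: `compute_bleu` is defined inline below rather than imported from
+nmt_bleu, and `_compute` calls it with max_order=1, so the reported score is
+effectively unigram precision scaled by the brevity penalty.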
""" + +import datasets + +#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py + + +_CITATION = """\ +@INPROCEEDINGS{Papineni02bleu:a, + author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, + title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, + booktitle = {}, + year = {2002}, + pages = {311--318} +} +@inproceedings{lin-och-2004-orange, + title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", + author = "Lin, Chin-Yew and + Och, Franz Josef", + booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", + month = "aug 23{--}aug 27", + year = "2004", + address = "Geneva, Switzerland", + publisher = "COLING", + url = "https://www.aclweb.org/anthology/C04-1072", + pages = "501--507", +} +""" + +_DESCRIPTION = """\ +BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. +Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, +the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and +remains one of the most popular automated and inexpensive metrics. + +Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. +Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness +are not taken into account[citation needed]. + +BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1 +representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the +reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional +reference translations will increase the BLEU score. +""" + +_KWARGS_DESCRIPTION = """ +Computes BLEU score of translated segments against one or more references. +Args: + predictions: list of translations to score. + Each translation should be tokenized into a list of tokens. + references: list of lists of references for each translation. + Each reference should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. +Returns: + 'bleu': bleu score, + 'precisions': geometric mean of n-gram precisions, + 'brevity_penalty': brevity penalty, + 'length_ratio': ratio of lengths, + 'translation_length': translation_length, + 'reference_length': reference_length +Examples: + + >>> predictions = [ + ... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample + ... ["foo", "bar", "foobar"] # tokenized prediction of the second sample + ... ] + >>> references = [ + ... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references) + ... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference) + ... 
] + >>> bleu = datasets.load_metric("bleu") + >>> results = bleu.compute(predictions=predictions, references=references) + >>> print(results["bleu"]) + 1.0 +""" + +import collections +import math + + +def _get_ngrams(segment, max_order): + """Extracts all n-grams upto a given maximum order from an input segment. + Args: + segment: text segment from which n-grams will be extracted. + max_order: maximum length in tokens of the n-grams returned by this + methods. + Returns: + The Counter containing all n-grams upto max_order in segment + with a count of how many times each n-gram occurred. + """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i:i+order]) + ngram_counts[ngram] += 1 + return ngram_counts + + +def compute_bleu(reference_corpus, translation_corpus, max_order=1, + smooth=False): + """Computes BLEU score of translated segments against one or more references. + Args: + reference_corpus: list of lists of references for each translation. Each + reference should be tokenized into a list of tokens. + translation_corpus: list of translations to score. Each translation + should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. + Returns: + 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram + precisions and brevity penalty. + """ + matches_by_order = [0] * max_order + possible_matches_by_order = [0] * max_order + reference_length = 0 + translation_length = 0 + for (references, translation) in zip(reference_corpus, + translation_corpus): + # translation = references[0] + translation = [item.lower() for item in translation] + + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + # print(merged_ref_ngram_counts) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram)-1] += overlap[ngram] + for order in range(1, max_order+1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order-1] += possible_matches + + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = ((matches_by_order[i] + 1.) / + (possible_matches_by_order[i] + 1.)) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) + else: + precisions[i] = 0.0 + + + if min(precisions) > 0: + p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1. + else: + bp = math.exp(1 - 1. 
/ ratio) + + bleu = geo_mean * bp + + + return (bleu, precisions, bp, ratio, translation_length, reference_length) + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Bleu(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), + "references": datasets.Sequence( + datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references" + ), + } + ), + codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], + reference_urls=[ + "https://en.wikipedia.org/wiki/BLEU", + "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", + ], + ) + + def _compute(self, predictions, references, max_order=4, smooth=False): + score = compute_bleu( + reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth + ) + (bleu, precisions, bp, ratio, translation_length, reference_length) = score + return { + "bleu": bleu, + "precisions": precisions, + "brevity_penalty": bp, + "length_ratio": ratio, + "translation_length": translation_length, + "reference_length": reference_length, + } diff --git a/diana/downstream/metric/rouge_local/rouge_metric.py b/diana/downstream/metric/rouge_local/rouge_metric.py new file mode 100644 index 0000000..66eb7f1 --- /dev/null +++ b/diana/downstream/metric/rouge_local/rouge_metric.py @@ -0,0 +1,133 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ROUGE metric from Google Research github repo. """ + +# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt +import absl # Here to have a nice missing dependency error message early on +import nltk # Here to have a nice missing dependency error message early on +import numpy # Here to have a nice missing dependency error message early on +import six # Here to have a nice missing dependency error message early on +from .rouge_score import RougeScorer +from .scoring import * +from .rouge_tokenizers import * +import datasets + + +_CITATION = """\ +@inproceedings{lin-2004-rouge, + title = "{ROUGE}: A Package for Automatic Evaluation of Summaries", + author = "Lin, Chin-Yew", + booktitle = "Text Summarization Branches Out", + month = jul, + year = "2004", + address = "Barcelona, Spain", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/W04-1013", + pages = "74--81", +} +""" + +_DESCRIPTION = """\ +ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for +evaluating automatic summarization and machine translation software in natural language processing. 
+The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation. + +Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters. + +This metrics is a wrapper around Google Research reimplementation of ROUGE: +https://github.com/google-research/google-research/tree/master/rouge +""" + +_KWARGS_DESCRIPTION = """ +Calculates average rouge scores for a list of hypotheses and references +Args: + predictions: list of predictions to score. Each prediction + should be a string with tokens separated by spaces. + references: list of reference for each prediction. Each + reference should be a string with tokens separated by spaces. + rouge_types: A list of rouge types to calculate. + Valid names: + `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring, + `"rougeL"`: Longest common subsequence based scoring. + `"rougeLSum"`: rougeLsum splits text using `"\n"`. + See details in https://github.com/huggingface/datasets/issues/617 + use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes. + use_aggregator: Return aggregates if this is set to True +Returns: + rouge1: rouge_1 (precision, recall, f1), + rouge2: rouge_2 (precision, recall, f1), + rougeL: rouge_l (precision, recall, f1), + rougeLsum: rouge_lsum (precision, recall, f1) +Examples: + + >>> rouge = datasets.load_metric('rouge') + >>> predictions = ["hello there", "general kenobi"] + >>> references = ["hello there", "general kenobi"] + >>> results = rouge.compute(predictions=predictions, references=references) + >>> print(list(results.keys())) + ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'] + >>> print(results["rouge1"]) + AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)) + >>> print(results["rouge1"].mid.fmeasure) + 1.0 +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Rouge(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Value("string", id="sequence"), + } + ), + codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"], + reference_urls=[ + "https://en.wikipedia.org/wiki/ROUGE_(metric)", + "https://github.com/google-research/google-research/tree/master/rouge", + ], + ) + + def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False): + if rouge_types is None: + rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + + scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer) + if use_aggregator: + aggregator = BootstrapAggregator() + else: + scores = [] + + for ref, pred in zip(references, predictions): + score = scorer.score(ref, pred) + if use_aggregator: + aggregator.add_scores(score) + else: + scores.append(score) + + if use_aggregator: + result = aggregator.aggregate() + for key in rouge_types: + result[key] = result[key].mid.fmeasure + else: + result = {} + for key in scores[0]: + result[key] = list(score[key] for score in scores) + + return result diff --git a/diana/downstream/metric/rouge_local/rouge_score.py 
b/diana/downstream/metric/rouge_local/rouge_score.py new file mode 100644 index 0000000..eaed59b --- /dev/null +++ b/diana/downstream/metric/rouge_local/rouge_score.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Computes rouge scores between two text blobs. +Implementation replicates the functionality in the original ROUGE package. See: +Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In +Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), +Barcelona, Spain, July 25 - 26, 2004. +Default options are equivalent to running: +ROUGE-1.5.5.pl -e data -n 2 -a settings.xml +Or with use_stemmer=True: +ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml +In these examples settings.xml lists input files and formats. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from absl import logging +import nltk +import numpy as np +import six +from six.moves import map +from six.moves import range +from .scoring import * +from .rouge_tokenizers import * + + + +class RougeScorer(BaseScorer): + """Calculate rouges scores between two blobs of text. + Sample usage: + scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True) + scores = scorer.score('The quick brown fox jumps over the lazy dog', + 'The quick brown dog jumps on the log.') + """ + + def __init__(self, rouge_types, use_stemmer=False, split_summaries=False, + tokenizer=None): + """Initializes a new RougeScorer. + Valid rouge types that can be computed are: + rougen (e.g. rouge1, rouge2): n-gram based scoring. + rougeL: Longest common subsequence based scoring. + Args: + rouge_types: A list of rouge types to calculate. + use_stemmer: Bool indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. This arg is used in the + DefaultTokenizer, but other tokenizers might or might not choose to + use this. + split_summaries: whether to add newlines between sentences for rougeLsum + tokenizer: Tokenizer object which has a tokenize() method. + Returns: + A dict mapping rouge types to Score tuples. + """ + + self.rouge_types = rouge_types + if tokenizer: + self._tokenizer = tokenizer + else: + self._tokenizer = DefaultTokenizer(use_stemmer) + logging.info("Using default tokenizer.") + + self._split_summaries = split_summaries + + def score_multi(self, targets, prediction): + """Calculates rouge scores between targets and prediction. + The target with the maximum f-measure is used for the final score for + each score type.. + Args: + targets: list of texts containing the targets + prediction: Text containing the predicted text. + Returns: + A dict mapping each rouge type to a Score object. + Raises: + ValueError: If an invalid rouge type is encountered. 
+ """ + score_dicts = [self.score(t, prediction) for t in targets] + max_score = {} + for k in self.rouge_types: + index = np.argmax([s[k].fmeasure for s in score_dicts]) + max_score[k] = score_dicts[index][k] + + return max_score + + def score(self, target, prediction): + """Calculates rouge scores between the target and prediction. + Args: + target: Text containing the target (ground truth) text, + or if a list + prediction: Text containing the predicted text. + Returns: + A dict mapping each rouge type to a Score object. + Raises: + ValueError: If an invalid rouge type is encountered. + """ + + # Pre-compute target tokens and prediction tokens for use by different + # types, except if only "rougeLsum" is requested. + if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum": + target_tokens = None + prediction_tokens = None + else: + target_tokens = self._tokenizer.tokenize(target) + prediction_tokens = self._tokenizer.tokenize(prediction) + result = {} + + for rouge_type in self.rouge_types: + if rouge_type == "rougeL": + # Rouge from longest common subsequences. + scores = _score_lcs(target_tokens, prediction_tokens) + elif rouge_type == "rougeLsum": + # Note: Does not support multi-line text. + def get_sents(text): + if self._split_summaries: + sents = nltk.sent_tokenize(text) + else: + # Assume sentences are separated by newline. + sents = six.ensure_str(text).split("\n") + sents = [x for x in sents if len(x)] + return sents + + target_tokens_list = [ + self._tokenizer.tokenize(s) for s in get_sents(target)] + prediction_tokens_list = [ + self._tokenizer.tokenize(s) for s in get_sents(prediction)] + + scores = _summary_level_lcs(target_tokens_list, + prediction_tokens_list) + elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)): + # Rouge from n-grams. + n = int(rouge_type[5:]) + if n <= 0: + raise ValueError("rougen requires positive n: %s" % rouge_type) + target_ngrams = _create_ngrams(target_tokens, n) + prediction_ngrams = _create_ngrams(prediction_tokens, n) + scores = _score_ngrams(target_ngrams, prediction_ngrams) + else: + raise ValueError("Invalid rouge type: %s" % rouge_type) + result[rouge_type] = scores + + return result + + +def _create_ngrams(tokens, n): + """Creates ngrams from the given list of tokens. + Args: + tokens: A list of tokens from which ngrams are created. + n: Number of tokens to use, e.g. 2 for bigrams. + Returns: + A dictionary mapping each bigram to the number of occurrences. + """ + + ngrams = collections.Counter() + for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)): + ngrams[ngram] += 1 + return ngrams + + +def _score_lcs(target_tokens, prediction_tokens): + """Computes LCS (Longest Common Subsequence) rouge scores. + Args: + target_tokens: Tokens from the target text. + prediction_tokens: Tokens from the predicted text. + Returns: + A Score object containing computed scores. + """ + + if not target_tokens or not prediction_tokens: + return Score(precision=0, recall=0, fmeasure=0) + + # Compute length of LCS from the bottom up in a table (DP appproach). 
+ lcs_table = _lcs_table(target_tokens, prediction_tokens) + lcs_length = lcs_table[-1][-1] + + precision = lcs_length / len(prediction_tokens) + recall = lcs_length / len(target_tokens) + fmeasure = func_fmeasure(precision, recall) + + return Score(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _lcs_table(ref, can): + """Create 2-d LCS score table.""" + rows = len(ref) + cols = len(can) + lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)] + for i in range(1, rows + 1): + for j in range(1, cols + 1): + if ref[i - 1] == can[j - 1]: + lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1 + else: + lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1]) + return lcs_table + + +def _backtrack_norec(t, ref, can): + """Read out LCS.""" + i = len(ref) + j = len(can) + lcs = [] + while i > 0 and j > 0: + if ref[i - 1] == can[j - 1]: + lcs.insert(0, i-1) + i -= 1 + j -= 1 + elif t[i][j - 1] > t[i - 1][j]: + j -= 1 + else: + i -= 1 + return lcs + + +def _summary_level_lcs(ref_sent, can_sent): + """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper. + Args: + ref_sent: list of tokenized reference sentences + can_sent: list of tokenized candidate sentences + Returns: + summary level ROUGE score + """ + if not ref_sent or not can_sent: + return Score(precision=0, recall=0, fmeasure=0) + + m = sum(map(len, ref_sent)) + n = sum(map(len, can_sent)) + if not n or not m: + return Score(precision=0, recall=0, fmeasure=0) + + # get token counts to prevent double counting + token_cnts_r = collections.Counter() + token_cnts_c = collections.Counter() + for s in ref_sent: + # s is a list of tokens + token_cnts_r.update(s) + for s in can_sent: + token_cnts_c.update(s) + + hits = 0 + for r in ref_sent: + lcs = _union_lcs(r, can_sent) + # Prevent double-counting: + # The paper describes just computing hits += len(_union_lcs()), + # but the implementation prevents double counting. We also + # implement this as in version 1.5.5. + for t in lcs: + if token_cnts_c[t] > 0 and token_cnts_r[t] > 0: + hits += 1 + token_cnts_c[t] -= 1 + token_cnts_r[t] -= 1 + + recall = hits / m + precision = hits / n + fmeasure = func_fmeasure(precision, recall) + return Score(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _union_lcs(ref, c_list): + """Find union LCS between a ref sentence and list of candidate sentences. + Args: + ref: list of tokens + c_list: list of list of indices for LCS into reference summary + Returns: + List of tokens in ref representing union LCS. + """ + lcs_list = [lcs_ind(ref, c) for c in c_list] + return [ref[i] for i in _find_union(lcs_list)] + + +def _find_union(lcs_list): + """Finds union LCS given a list of LCS.""" + return sorted(list(set().union(*lcs_list))) + + +def lcs_ind(ref, can): + """Returns one of the longest lcs.""" + t = _lcs_table(ref, can) + return _backtrack_norec(t, ref, can) + + +def _score_ngrams(target_ngrams, prediction_ngrams): + """Compute n-gram based rouge scores. + Args: + target_ngrams: A Counter object mapping each ngram to number of + occurrences for the target text. + prediction_ngrams: A Counter object mapping each ngram to number of + occurrences for the prediction text. + Returns: + A Score object containing computed scores. 
+ """ + + intersection_ngrams_count = 0 + for ngram in six.iterkeys(target_ngrams): + intersection_ngrams_count += min(target_ngrams[ngram], + prediction_ngrams[ngram]) + target_ngrams_count = sum(target_ngrams.values()) + prediction_ngrams_count = sum(prediction_ngrams.values()) + + precision = intersection_ngrams_count / max(prediction_ngrams_count, 1) + recall = intersection_ngrams_count / max(target_ngrams_count, 1) + fmeasure = func_fmeasure(precision, recall) + + return Score(precision=precision, recall=recall, fmeasure=fmeasure) \ No newline at end of file diff --git a/diana/downstream/metric/rouge_local/rouge_tokenizers.py b/diana/downstream/metric/rouge_local/rouge_tokenizers.py new file mode 100644 index 0000000..91e835b --- /dev/null +++ b/diana/downstream/metric/rouge_local/rouge_tokenizers.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +"""Library containing Tokenizer definitions. +The RougeScorer class can be instantiated with the tokenizers defined here. New +tokenizers can be defined by creating a subclass of the Tokenizer abstract class +and overriding the tokenize() method. +""" +import abc +from nltk.stem import porter + + +import re +import six + + +# Pre-compile regexes that are use often +NON_ALPHANUM_PATTERN = r"[^a-z0-9]+" +NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN) +SPACES_PATTERN = r"\s+" +SPACES_RE = re.compile(SPACES_PATTERN) +VALID_TOKEN_PATTERN = r"^[a-z0-9]+$" +VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN) + + +def _tokenize(text, stemmer): + """Tokenize input text into a list of tokens. + This approach aims to replicate the approach taken by Chin-Yew Lin in + the original ROUGE implementation. + Args: + text: A text blob to tokenize. + stemmer: An optional stemmer. + Returns: + A list of string tokens extracted from input text. + """ + + # Convert everything to lowercase. + text = text.lower() + # Replace any non-alpha-numeric characters with spaces. + text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text)) + + tokens = SPACES_RE.split(text) + if stemmer: + # Only stem words more than 3 characters long. + tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x + for x in tokens] + + # One final check to drop any empty or invalid tokens. + tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)] + + return tokens + + +class Tokenizer(abc.ABC): + """Abstract base class for a tokenizer. + Subclasses of Tokenizer must implement the tokenize() method. + """ + + @abc.abstractmethod + def tokenize(self, text): + raise NotImplementedError("Tokenizer must override tokenize() method") + + +class DefaultTokenizer(Tokenizer): + """Default tokenizer which tokenizes on whitespace.""" + + def __init__(self, use_stemmer=False): + """Constructor for DefaultTokenizer. 
+ Args: + use_stemmer: boolean, indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. + """ + self._stemmer = porter.PorterStemmer() if use_stemmer else None + + def tokenize(self, text): + return _tokenize(text, self._stemmer) \ No newline at end of file diff --git a/diana/downstream/metric/rouge_local/scoring.py b/diana/downstream/metric/rouge_local/scoring.py new file mode 100644 index 0000000..0a924ae --- /dev/null +++ b/diana/downstream/metric/rouge_local/scoring.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library for scoring and evaluation of text samples. +Aggregation functions use bootstrap resampling to compute confidence intervals +as per the original ROUGE perl implementation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +from typing import Dict + +import numpy as np +import six +from six.moves import range + + +class Score( + collections.namedtuple("Score", ["precision", "recall", "fmeasure"])): + """Tuple containing precision, recall, and f-measure values.""" + + +class BaseScorer(object, metaclass=abc.ABCMeta): + """Base class for Scorer objects.""" + + @abc.abstractmethod + def score(self, target, prediction): + """Calculates score between the target and prediction. + Args: + target: Text containing the target (ground truth) text. + prediction: Text containing the predicted text. + Returns: + A dict mapping each score_type (string) to Score object. + """ + + +class AggregateScore( + collections.namedtuple("AggregateScore", ["low", "mid", "high"])): + """Tuple containing confidence intervals for scores.""" + + +class BootstrapAggregator(object): + """Aggregates scores to provide confidence intervals. + Sample usage: + scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL']) + aggregator = Aggregator() + aggregator.add_scores(scorer.score("one two three", "one two")) + aggregator.add_scores(scorer.score("one two five six", "seven eight")) + result = aggregator.aggregate() + print result + {'rougeL': AggregateScore( + low=Score(precision=0.0, recall=0.0, fmeasure=0.0), + mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), + high=Score(precision=1.0, recall=0.66, fmeasure=0.80)), + 'rouge1': AggregateScore( + low=Score(precision=0.0, recall=0.0, fmeasure=0.0), + mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), + high=Score(precision=1.0, recall=0.66, fmeasure=0.80))} + """ + + def __init__(self, confidence_interval=0.95, n_samples=1000): + """Initializes a BootstrapAggregator object. + Args: + confidence_interval: Confidence interval to compute on the mean as a + decimal. + n_samples: Number of samples to use for bootstrap resampling. + Raises: + ValueError: If invalid argument is given. 
+ """ + + if confidence_interval < 0 or confidence_interval > 1: + raise ValueError("confidence_interval must be in range [0, 1]") + if n_samples <= 0: + raise ValueError("n_samples must be positive") + + self._n_samples = n_samples + self._confidence_interval = confidence_interval + self._scores = collections.defaultdict(list) + + def add_scores(self, scores): + """Adds a sample for future aggregation. + Args: + scores: Dict mapping score_type strings to a namedtuple object/class + representing a score. + """ + + for score_type, score in six.iteritems(scores): + self._scores[score_type].append(score) + + def aggregate(self): + """Aggregates scores previously added using add_scores. + Returns: + A dict mapping score_type to AggregateScore objects. + """ + + result = {} + for score_type, scores in six.iteritems(self._scores): + # Stack scores into a 2-d matrix of (sample, measure). + score_matrix = np.vstack(tuple(scores)) + # Percentiles are returned as (interval, measure). + percentiles = self._bootstrap_resample(score_matrix) + # Extract the three intervals (low, mid, high). + intervals = tuple( + (scores[0].__class__(*percentiles[j, :]) for j in range(3))) + result[score_type] = AggregateScore( + low=intervals[0], mid=intervals[1], high=intervals[2]) + return result + + def _bootstrap_resample(self, matrix): + """Performs bootstrap resampling on a matrix of scores. + Args: + matrix: A 2-d matrix of (sample, measure). + Returns: + A 2-d matrix of (bounds, measure). There are three bounds: low (row 0), + mid (row 1) and high (row 2). Mid is always the mean, while low and high + bounds are specified by self._confidence_interval (which defaults to 0.95 + meaning it will return the 2.5th and 97.5th percentiles for a 95% + confidence interval on the mean). + """ + + # Matrix of (bootstrap sample, measure). + sample_mean = np.zeros((self._n_samples, matrix.shape[1])) + for i in range(self._n_samples): + sample_idx = np.random.choice( + np.arange(matrix.shape[0]), size=matrix.shape[0]) + sample = matrix[sample_idx, :] + sample_mean[i, :] = np.mean(sample, axis=0) + + # Take percentiles on the estimate of the mean using bootstrap samples. + # Final result is a (bounds, measure) matrix. + percentile_delta = (1 - self._confidence_interval) / 2 + q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta]) + return np.percentile(sample_mean, q, axis=0) + + +def func_fmeasure(precision, recall): + """Computes f-measure given precision and recall values.""" + + if precision + recall > 0: + return 2 * precision * recall / (precision + recall) + else: + return 0.0 \ No newline at end of file diff --git a/diana/downstream/metric/squad_v1_local/evaluate.py b/diana/downstream/metric/squad_v1_local/evaluate.py new file mode 100644 index 0000000..41effa7 --- /dev/null +++ b/diana/downstream/metric/squad_v1_local/evaluate.py @@ -0,0 +1,92 @@ +""" Official evaluation script for v1.1 of the SQuAD dataset. 
""" + +import argparse +import json +import re +import string +import sys +from collections import Counter + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article["paragraphs"]: + for qa in paragraph["qas"]: + total += 1 + if qa["id"] not in predictions: + message = "Unanswered question " + qa["id"] + " will receive score 0." + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x["text"], qa["answers"])) + prediction = predictions[qa["id"]] + exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {"exact_match": exact_match, "f1": f1} + + +if __name__ == "__main__": + expected_version = "1.1" + parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version) + parser.add_argument("dataset_file", help="Dataset file") + parser.add_argument("prediction_file", help="Prediction File") + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if dataset_json["version"] != expected_version: + print( + "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"], + file=sys.stderr, + ) + dataset = dataset_json["data"] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/diana/downstream/metric/squad_v1_local/squad_v1_local.py b/diana/downstream/metric/squad_v1_local/squad_v1_local.py new file mode 100644 index 0000000..1456f63 --- /dev/null +++ b/diana/downstream/metric/squad_v1_local/squad_v1_local.py @@ -0,0 +1,109 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SQuAD metric. """ + +import datasets + +from .evaluate import evaluate + + +_CITATION = """\ +@inproceedings{Rajpurkar2016SQuAD10, + title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, + author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, + booktitle={EMNLP}, + year={2016} +} +""" + +_DESCRIPTION = """ +This metric wrap the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD). + +Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by +crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, +from the corresponding reading passage, or the question might be unanswerable. +""" + +_KWARGS_DESCRIPTION = """ +Computes SQuAD scores (F1 and EM). +Args: + predictions: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair as given in the references (see below) + - 'prediction_text': the text of the answer + references: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair (see above), + - 'answers': a Dict in the SQuAD dataset format + { + 'text': list of possible texts for the answer, as a list of strings + 'answer_start': list of start positions for the answer, as a list of ints + } + Note that answer_start values are not taken into account to compute the metric. 
+Returns: + 'exact_match': Exact match (the normalized answer exactly match the gold answer) + 'f1': The F-score of predicted tokens versus the gold answer +Examples: + + >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}] + >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] + >>> squad_metric = datasets.load_metric("squad") + >>> results = squad_metric.compute(predictions=predictions, references=references) + >>> print(results) + {'exact_match': 100.0, 'f1': 100.0} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Squad(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")}, + "references": { + "id": datasets.Value("string"), + "answers": datasets.features.Sequence( + { + "text": datasets.Value("string"), + "answer_start": datasets.Value("int32"), + } + ), + }, + } + ), + codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + ) + + def _compute(self, predictions, references): + pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} + dataset = [ + { + "paragraphs": [ + { + "qas": [ + { + "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]], + "id": ref["id"], + } + for ref in references + ] + } + ] + } + ] + score = evaluate(dataset=dataset, predictions=pred_dict) + return score diff --git a/diana/downstream/metric/squad_v2_local/evaluate.py b/diana/downstream/metric/squad_v2_local/evaluate.py new file mode 100644 index 0000000..1a04e53 --- /dev/null +++ b/diana/downstream/metric/squad_v2_local/evaluate.py @@ -0,0 +1,330 @@ +"""Official evaluation script for SQuAD version 2.0. + +In addition to basic functionality, we also compute additional statistics and +plot precision-recall curves if an additional na_prob.json file is provided. +This file is expected to map question ID's to the model's predicted probability +that a question is unanswerable. +""" +import argparse +import collections +import json +import os +import re +import string +import sys + +import numpy as np + + +OPTS = None + + +def parse_args(): + parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.") + parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.") + parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.") + parser.add_argument( + "--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)." + ) + parser.add_argument( + "--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer." + ) + parser.add_argument( + "--na-prob-thresh", + "-t", + type=float, + default=1.0, + help='Predict "" if no-answer probability exceeds this (default = 1.0).', + ) + parser.add_argument( + "--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory." 
+ ) + parser.add_argument("--verbose", "-v", action="store_true") + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def make_qid_to_has_ans(dataset): + qid_to_has_ans = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"]) + return qid_to_has_ans + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +def compute_exact(a_gold, a_pred): + return int(normalize_answer(a_gold) == normalize_answer(a_pred)) + + +def compute_f1(a_gold, a_pred): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def get_raw_scores(dataset, preds): + exact_scores = {} + f1_scores = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid = qa["id"] + if qid not in preds: + print(f"Missing prediction for {qid}") + continue + a_pred = preds[qid] + + gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)] + if not gold_answers: + # For unanswerable questions, only correct answer is empty string + gold_answers = [""] + if a_pred != "": + exact_scores[qid] = 0 + f1_scores[qid] = 0 + else: + exact_scores[qid] = 1 + f1_scores[qid] = 1 + else: + # Take max over all gold answers + exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) + f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) + return exact_scores, f1_scores + + +def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): + new_scores = {} + for qid, s in scores.items(): + pred_na = na_probs[qid] > na_prob_thresh + if pred_na: + new_scores[qid] = float(not qid_to_has_ans[qid]) + else: + new_scores[qid] = s + return new_scores + + +def make_eval_dict(exact_scores, f1_scores, qid_list=None): + if not qid_list: + total = len(exact_scores) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores.values()) / total), + ("f1", 100.0 * sum(f1_scores.values()) / total), + ("total", total), + ] + ) + else: + total = len(qid_list) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), + ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), + ("total", total), + ] + ) + + +def merge_eval(main_eval, new_eval, prefix): + for k in new_eval: + main_eval[f"{prefix}_{k}"] = new_eval[k] + + +def plot_pr_curve(precisions, recalls, out_image, title): + plt.step(recalls, precisions, color="b", alpha=0.2, where="post") + plt.fill_between(recalls, precisions, step="post", alpha=0.2, 
color="b") + plt.xlabel("Recall") + plt.ylabel("Precision") + plt.xlim([0.0, 1.05]) + plt.ylim([0.0, 1.05]) + plt.title(title) + plt.savefig(out_image) + plt.clf() + + +def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + true_pos = 0.0 + cur_p = 1.0 + cur_r = 0.0 + precisions = [1.0] + recalls = [0.0] + avg_prec = 0.0 + for i, qid in enumerate(qid_list): + if qid_to_has_ans[qid]: + true_pos += scores[qid] + cur_p = true_pos / float(i + 1) + cur_r = true_pos / float(num_true_pos) + if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: + # i.e., if we can put a threshold after this point + avg_prec += cur_p * (cur_r - recalls[-1]) + precisions.append(cur_p) + recalls.append(cur_r) + if out_image: + plot_pr_curve(precisions, recalls, out_image, title) + return {"ap": 100.0 * avg_prec} + + +def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir): + if out_image_dir and not os.path.exists(out_image_dir): + os.makedirs(out_image_dir) + num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) + if num_true_pos == 0: + return + pr_exact = make_precision_recall_eval( + exact_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_exact.png"), + title="Precision-Recall curve for Exact Match score", + ) + pr_f1 = make_precision_recall_eval( + f1_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_f1.png"), + title="Precision-Recall curve for F1 score", + ) + oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} + pr_oracle = make_precision_recall_eval( + oracle_scores, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_oracle.png"), + title="Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)", + ) + merge_eval(main_eval, pr_exact, "pr_exact") + merge_eval(main_eval, pr_f1, "pr_f1") + merge_eval(main_eval, pr_oracle, "pr_oracle") + + +def histogram_na_prob(na_probs, qid_list, image_dir, name): + if not qid_list: + return + x = [na_probs[k] for k in qid_list] + weights = np.ones_like(x) / float(len(x)) + plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) + plt.xlabel("Model probability of no-answer") + plt.ylabel("Proportion of dataset") + plt.title(f"Histogram of no-answer probability: {name}") + plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png")) + plt.clf() + + +def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + return 100.0 * best_score / len(scores), best_thresh + + +def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + + +def main(): + with open(OPTS.data_file) as f: + dataset_json = json.load(f) + dataset = dataset_json["data"] + with open(OPTS.pred_file) as f: + preds = json.load(f) + if OPTS.na_prob_file: + with open(OPTS.na_prob_file) as f: + na_probs = json.load(f) + else: + na_probs = {k: 0.0 for k in preds} + qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + exact_raw, f1_raw = get_raw_scores(dataset, preds) + exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + out_eval = make_eval_dict(exact_thresh, f1_thresh) + if has_ans_qids: + has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + merge_eval(out_eval, has_ans_eval, "HasAns") + if no_ans_qids: + no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + merge_eval(out_eval, no_ans_eval, "NoAns") + if OPTS.na_prob_file: + find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) + if OPTS.na_prob_file and OPTS.out_image_dir: + run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir) + histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns") + histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns") + if OPTS.out_file: + with open(OPTS.out_file, "w") as f: + json.dump(out_eval, f) + else: + print(json.dumps(out_eval, indent=2)) + + +if __name__ == "__main__": + OPTS = parse_args() + if OPTS.out_image_dir: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + main() \ No newline at end of file diff --git a/diana/downstream/metric/squad_v2_local/squad_v2_local.py 
b/diana/downstream/metric/squad_v2_local/squad_v2_local.py new file mode 100644 index 0000000..d88979a --- /dev/null +++ b/diana/downstream/metric/squad_v2_local/squad_v2_local.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SQuAD v2 metric. """ + +import datasets + +from .evaluate import ( + apply_no_ans_threshold, + find_all_best_thresh, + get_raw_scores, + make_eval_dict, + make_qid_to_has_ans, + merge_eval, +) + + +_CITATION = """\ +@inproceedings{Rajpurkar2016SQuAD10, + title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, + author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, + booktitle={EMNLP}, + year={2016} +} +""" + +_DESCRIPTION = """ +This metric wrap the official scoring script for version 2 of the Stanford Question +Answering Dataset (SQuAD). + +Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by +crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, +from the corresponding reading passage, or the question might be unanswerable. + +SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions +written adversarially by crowdworkers to look similar to answerable ones. +To do well on SQuAD2.0, systems must not only answer questions when possible, but also +determine when no answer is supported by the paragraph and abstain from answering. +""" + +_KWARGS_DESCRIPTION = """ +Computes SQuAD v2 scores (F1 and EM). +Args: + predictions: List of triple for question-answers to score with the following elements: + - the question-answer 'id' field as given in the references (see below) + - the text of the answer + - the probability that the question has no answer + references: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair (see above), + - 'answers': a list of Dict {'text': text of the answer as a string} + no_answer_threshold: float + Probability threshold to decide that a question has no answer. 
+Returns: + 'exact': Exact match (the normalized answer exactly matches the gold answer) + 'f1': The F-score of predicted tokens versus the gold answer + 'total': Number of scores considered + 'HasAns_exact': Exact match (the normalized answer exactly matches the gold answer) + 'HasAns_f1': The F-score of predicted tokens versus the gold answer + 'HasAns_total': Number of scores considered + 'NoAns_exact': Exact match (the normalized answer exactly matches the gold answer) + 'NoAns_f1': The F-score of predicted tokens versus the gold answer + 'NoAns_total': Number of scores considered + 'best_exact': Best exact match (with varying threshold) + 'best_exact_thresh': No-answer probability threshold associated with the best exact match + 'best_f1': Best F1 (with varying threshold) + 'best_f1_thresh': No-answer probability threshold associated with the best F1 +Examples: + + >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}] + >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] + >>> squad_v2_metric = datasets.load_metric("squad_v2") + >>> results = squad_v2_metric.compute(predictions=predictions, references=references) + >>> print(results) + {'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class SquadV2Local(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": { + "id": datasets.Value("string"), + "prediction_text": datasets.Value("string"), + "no_answer_probability": datasets.Value("float32"), + }, + "references": { + "id": datasets.Value("string"), + "answers": datasets.features.Sequence( + {"text": datasets.Value("string"), "answer_start": datasets.Value("int32")} + ), + }, + } + ), + codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + ) + + def _compute(self, predictions, references, no_answer_threshold=1.0): + # no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions) + dataset = [{"paragraphs": [{"qas": references}]}] + predictions = dict((p["id"], p["prediction_text"]) for p in predictions) + + # qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + # has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + # no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + + exact_raw, f1_raw = get_raw_scores(dataset, predictions) + out_eval = make_eval_dict(exact_raw, f1_raw) + # + # + # exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + # f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + # out_eval = make_eval_dict(exact_thresh, f1_thresh) + # + # if has_ans_qids: + # has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + # merge_eval(out_eval, has_ans_eval, "HasAns") + # if no_ans_qids: + # no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + # merge_eval(out_eval, no_ans_eval, "NoAns") + # find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans) +
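# Only the aggregate 'exact', 'f1' and 'total' scores are computed in this local variant; the commented-out code above would restore the no-answer threshold search and the HasAns/NoAns breakdown. +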
return dict(out_eval) \ No newline at end of file diff --git a/diana/downstream/simple_processors.py b/diana/downstream/simple_processors.py new file mode 100755 index 0000000..6505679 --- /dev/null +++ b/diana/downstream/simple_processors.py @@ -0,0 +1,192 @@ +import sys +sys.path.append('../data_process') +from typing import List, Optional, Tuple +from QAInput import SimpleQAInput as QAInput +#from QAInput import QAInputNoPrompt as QAInput + +def preprocess_sqaud_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_sqaud_abstractive_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_boolq_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + question_column, context_column, answer_column = 'question', 'passage', 'answer' + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_boolq_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_narrativeqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [exp['summary']['text'] for exp in examples['document']] + questions = [exp['text'] for exp in examples['question']] + answers = [ans[0]['text'] for ans in examples['answers']] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_narrativeqa_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + 
targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_drop_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['passage'] + questions = examples['question'] + answers = examples['answers'] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_race_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 } + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_newsqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + + +def preprocess_ropes_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + backgrounds = examples["background"] + situations = examples["situation"] + answers = examples[answer_column] + + inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_openbookqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples['question_stem'] + all_options = examples['choices'] + answers = examples['answerKey'] + options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options] + inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_social_iqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. 
{options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2,} + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_dream_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [" ".join(dialogue) for dialogue in examples['dialogue']] + questions = examples['question'] + all_options = examples['choice'] + answers = examples['answer'] + answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + targets = answers + return inputs, targets \ No newline at end of file diff --git a/diana/downstreamdeca/QAInput.py b/diana/downstreamdeca/QAInput.py new file mode 100644 index 0000000..7c45e12 --- /dev/null +++ b/diana/downstreamdeca/QAInput.py @@ -0,0 +1,21 @@ +#ProQA +#[paper]https://arxiv.org/abs/2205.04040 +class StructuralQAInput: + @classmethod + def qg_input(cls, context, question, options=None): + question = question + source_text = f'[QUESTION] {question}. [CONTEXT] {context} the answer is: ' + return source_text + + + + + + +class SimpleQAInput: + @classmethod + def qg_input(cls, context, question, options=None): + question = question + source_text = f'{question} \\n {context}' + return source_text + diff --git a/diana/downstreamdeca/dataset_processors.py b/diana/downstreamdeca/dataset_processors.py new file mode 100755 index 0000000..2b056b5 --- /dev/null +++ b/diana/downstreamdeca/dataset_processors.py @@ -0,0 +1,192 @@ +import sys +sys.path.append('../data_process') +from typing import List, Optional, Tuple +from QAInput import StructuralQAInput as QAInput +#from QAInput import QAInputNoPrompt as QAInput + +def preprocess_sqaud_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_sqaud_abstractive_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_boolq_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + question_column, context_column, answer_column = 'question', 'passage', 'answer' + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + 
targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_boolq_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_narrativeqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [exp['summary']['text'] for exp in examples['document']] + questions = [exp['text'] for exp in examples['question']] + answers = [ans[0]['text'] for ans in examples['answers']] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_narrativeqa_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_drop_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['passage'] + questions = examples['question'] + answers = examples['answers'] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_race_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. 
{options[3]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 } + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_newsqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + + +def preprocess_ropes_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + backgrounds = examples["background"] + situations = examples["situation"] + answers = examples[answer_column] + + inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_openbookqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples['question_stem'] + all_options = examples['choices'] + answers = examples['answerKey'] + options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options] + inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_social_iqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2,} + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_dream_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [" ".join(dialogue) for dialogue in examples['dialogue']] + questions = examples['question'] + all_options = examples['choice'] + answers = examples['answer'] + answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. 
{options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + targets = answers + return inputs, targets \ No newline at end of file diff --git a/diana/downstreamdeca/diana.py b/diana/downstreamdeca/diana.py new file mode 100644 index 0000000..e3b8a5c --- /dev/null +++ b/diana/downstreamdeca/diana.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python +# coding=utf-8 + +import logging +import os +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metadeca import T5ForConditionalGeneration as PromptT5 +from metrics import compute_metrics +from downstreamdeca.simple_processors import * +from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], #,'ropes' + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} +task2id = {} +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + + +task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8} +task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0} + + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
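+ # This tokenization call produces no overflowing features, so each feature maps one-to-one to its source example.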
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + with torch.no_grad(): + device = 'cuda:0' + search_upbound = len(trainset)//4 + query_idxs = [None]*30 + keys = model.encoder.domain_keys + for idx,item in enumerate(trainset.select(range(search_upbound))): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + result = torch.matmul(query,keys.t()) + result = torch.topk(result,5).indices[0].cpu().numpy().tolist() + key_sel = None + for key_idx in result: + if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3: + key_sel = key_idx + break + if key_sel is not None: + if query_idxs[key_sel] is None: + query_idxs[key_sel] = [idx] + else: + query_idxs[key_sel].append(idx) + total_idxs = [] + for item in query_idxs: + try: + total_idxs.extend(item[:3]) + except: + total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3)) + total_idxs = list(set(total_idxs)) + total_idxs = random.sample(total_idxs,50) + sub_set = trainset.select(total_idxs) + + features = [] + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': 
preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
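+ # set_seed fixes the Python, NumPy and PyTorch RNGs so results are reproducible across runs.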
+ set_seed(training_args.seed) + + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]', + # '[OPTIONS]']) + tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]'] + special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]', + '[OPTIONS]']} + + tokenizer.add_tokens(tokens_to_add) + num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) + added_tokens = tokenizer.get_added_vocab() + logger.info('Added tokens: {}'.format(added_tokens)) + + model = PromptT5.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + # task_num = data_args.max_task_num, + # prompt_num = data_args.prompt_number, + # format_num = data_args.qa_task_type_num, + # add_task_prompt = False + ) + + model.resize_token_embeddings(len(tokenizer)) + + + #reload format specific task-prompt for newly involved task +#format_prompts###task_promptsf + + data_args.reload_from_trained_prompt = False#@ + data_args.load_from_format_task_id = False#@ + ### before pretrain come !!!!!! 
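+ # Optionally seed this task's prompt vectors from its format-level prompt, or restore them from a previously trained prompt checkpoint; both flags are forced off just above.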
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialized format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of the stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id stored in the trained prompt does not match the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the formats do not match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restored task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not training_args.no_cuda else "cpu") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
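+ # max_target_length caps tokenized labels; padding="max_length" pads at tokenization time, otherwise DataCollatorForSeq2Seq pads each batch dynamically.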
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ["squad","wikisql","sst","srl","woz.en"]: + + if True: + cur_dataset = ds_name + train_dataloaders[ds_name] = load_from_disk("./oursdeca/{}-train.hf".format(cur_dataset)).select(range(200)) + eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdeca/{}-evalex.hf".format(cur_dataset)).select(range(200))) + + for ds_name in ["cnn_dailymail","multinli.in.out","zre"]: + eval_dataloaders[ds_name] = (load_from_disk("./oursdeca/{}-eval.hf".format(ds_name)),load_from_disk("./oursdeca/{}-evalex.hf".format(ds_name))) + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ["squad","wikisql","sst","srl","woz.en"] + # task_sequence = ["woz.en","srl","sst","wikisql","squad"] + need_to_do_dss = [] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]]) + else: + fused = train_dataloaders[cur_dataset] + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + 
data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + if training_args.local_rank<=0: + save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset]) + + if all_replay is None: + all_replay = save_set + else: + all_replay = datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + model.encoder.add_negs(all_ids,all_features) + + + + + + + + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstreamdeca/diana.sh b/diana/downstreamdeca/diana.sh new file mode 100644 index 0000000..80917d8 --- /dev/null +++ b/diana/downstreamdeca/diana.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 diana.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + --load_best_model_at_end True \ + 
--greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstreamdeca/diana_log.txt b/diana/downstreamdeca/diana_log.txt new file mode 100644 index 0000000..d4bfcde --- /dev/null +++ b/diana/downstreamdeca/diana_log.txt @@ -0,0 +1,6 @@ +after_train_ squad _test_ squad +OrderedDict([('eval_em', 88.0), ('eval_nf1', 91.79166666666667), ('eval_nem', 89.5), ('eval_samples', 200)]) +after_train_ wikisql _test_ squad +OrderedDict([('eval_em', 83.0), ('eval_nf1', 87.73095238095239), ('eval_nem', 84.5), ('eval_samples', 200)]) +after_train_ wikisql _test_ wikisql +OrderedDict([('eval_lfem', 30.5), ('eval_em', 17.5), ('eval_nf1', 84.06011500628662), ('eval_nem', 20.0), ('eval_samples', 200)]) diff --git a/diana/downstreamdeca/dianawometa.py b/diana/downstreamdeca/dianawometa.py new file mode 100644 index 0000000..c1419a0 --- /dev/null +++ b/diana/downstreamdeca/dianawometa.py @@ -0,0 +1,775 @@ +#!/usr/bin/env python +# coding=utf-8 + +import logging +import os +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metadecanometa import T5ForConditionalGeneration as PromptT5 +from metrics import compute_metrics +from downstreamdeca.simple_processors import * +from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. 
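+# (Only relevant for multilingual checkpoints such as mBART/M2M; the accompanying .sh scripts fine-tune t5-base.)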
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], #,'ropes' + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} +task2id = {} +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + + +task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8} +task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0} + + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." 
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + with torch.no_grad(): + device = 'cuda:0' + + sub_set = trainset.select(range(50)) + + features = [] + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 
'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+    set_seed(training_args.seed)
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
+    #                       '[OPTIONS]'])
+    tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
+    special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
+                                                         '[OPTIONS]']}
+
+    tokenizer.add_tokens(tokens_to_add)
+    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+    added_tokens = tokenizer.get_added_vocab()
+    logger.info('Added tokens: {}'.format(added_tokens))
+
+    model = PromptT5.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+        # task_num = data_args.max_task_num,
+        # prompt_num = data_args.prompt_number,
+        # format_num = data_args.qa_task_type_num,
+        # add_task_prompt = False
+    )
+
+    model.resize_token_embeddings(len(tokenizer))
+
+
+    # Reload format-specific task prompts for newly involved tasks (disabled below).
+
+    data_args.reload_from_trained_prompt = False  # hard-coded off for these runs
+    data_args.load_from_format_task_id = False    # hard-coded off for these runs
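+    # When enabled, load_from_format_task_id seeds the new task's prompt rows from its format-level prompt,
+    # and reload_from_trained_prompt restores task+format prompt rows from a saved prompt_embedding_info checkpoint.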
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
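+    # Padding to max_length happens only with --pad_to_max_length; otherwise DataCollatorForSeq2Seq
+    # pads each batch dynamically (to a multiple of 8 when fp16 is enabled).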
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ["squad","wikisql","sst","srl","woz.en"]: + + if True: + cur_dataset = ds_name + train_dataloaders[ds_name] = load_from_disk("./oursdecanometa/{}-train.hf".format(cur_dataset)).select(range(200)) + eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(cur_dataset)).select(range(200))) + + for ds_name in ["cnn_dailymail","multinli.in.out","zre"]: + eval_dataloaders[ds_name] = (load_from_disk("./oursdecanometa/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanometa/{}-evalex.hf".format(ds_name))) + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ["squad","wikisql","sst","srl","woz.en"] + # task_sequence = ["woz.en","srl","sst","wikisql","squad"] + need_to_do_dss = [] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]]) + else: + fused = train_dataloaders[cur_dataset] + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + 
tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + if training_args.local_rank<=0: + save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset]) + + if all_replay is None: + all_replay = save_set + else: + all_replay = datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + model.encoder.add_negs(all_ids,all_features) + + + + + + + + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstreamdeca/dianawometa.sh b/diana/downstreamdeca/dianawometa.sh new file mode 100644 index 0000000..ff3a4fb --- /dev/null +++ b/diana/downstreamdeca/dianawometa.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 dianawometa.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + 
--load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstreamdeca/dianawotask.py b/diana/downstreamdeca/dianawotask.py new file mode 100644 index 0000000..bf82e4e --- /dev/null +++ b/diana/downstreamdeca/dianawotask.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python +# coding=utf-8 + +import logging +import os +import torch +import copy,random +import sys +import json +from dataclasses import dataclass, field +from typing import Optional +from typing import List, Optional, Tuple +from sklearn.cluster import KMeans +# from data import QAData +sys.path.append("../") +from models.metadecanotask import T5ForConditionalGeneration as PromptT5 +from metrics import compute_metrics +from downstreamdeca.simple_processors import * +from downstreamdeca.l2ptrainer import QuestionAnsweringTrainer +import datasets +import numpy as np +from datasets import load_dataset, load_metric,load_from_disk,concatenate_datasets +import os + +from functools import partial +os.environ["WANDB_DISABLED"] = "true" +import transformers +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) +from pathlib import Path +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +logger = logging.getLogger(__name__) + +# A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. +MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] +format2id = {'extractive':0,'abstractive':1,'multichoice':2} +format2dataset = { + 'extractive':['squad','extractive','newsqa','quoref'], #,'ropes' + 'abstractive':['narrativeqa','abstractive','nqopen','drop'], + 'multichoice':['race','multichoice','openbookqa','mctest','social_iqa','dream'] +} +seed_datasets = ['race','narrativeqa','squad'] +dataset2format= {} +task2id = {} +for k,vs in format2dataset.items(): + for v in vs: + dataset2format[v] = k + + +task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8} +task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0} + + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + fix_t5: Optional[bool] = field( + default=False, + metadata={"help": "whether to fix the main parameters of t5"} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) + target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on " + "a jsonlines file." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file." + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + after_train: Optional[str] = field(default=None) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + ) + forced_bos_token: Optional[str] = field( + default=None, + metadata={ + "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`." 
+ "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token " + "needs to be the target language token.(Usually it is the target language token)" + }, + ) + do_lowercase: bool = field( + default=False, + metadata={ + 'help':'Whether to process input into lowercase' + } + ) + append_another_bos: bool = field( + default=False, + metadata={ + 'help':'Whether to add another bos' + } + ) + context_column: Optional[str] = field( + default="context", + metadata={"help": "The name of the column in the datasets containing the contexts (for question answering)."}, + ) + question_column: Optional[str] = field( + default="question", + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + answer_column: Optional[str] = field( + default="answers", + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + max_task_num: Optional[int] = field( + default=30, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + qa_task_type_num: Optional[int] = field( + default=4, + metadata={"help": "The name of the column in the datasets containing the questions (for question answering)."}, + ) + prompt_number: Optional[int] = field( + default=100, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + add_task_prompt: Optional[bool] = field( + default=False, + metadata={"help": "The name of the column in the datasets containing the answers (for question answering)."}, + ) + reload_from_trained_prompt: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from trained prompt"}, + ) + trained_prompt_path:Optional[str] = field( + default = None, + metadata={ + 'help':'the path storing trained prompt embedding' + } + ) + load_from_format_task_id: Optional[bool] = field( + default=False, + metadata={"help": "whether to reload prompt from format-corresponding task prompt"} + ) + + + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + # elif self.source_lang is None or self.target_lang is None: + # raise ValueError("Need to specify the source language and the target language.") + + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + # assert extension == "json", "`train_file` should be a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + # assert extension == "json", "`validation_file` should be a json file." + if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +# + +def main(): + + + + def preprocess_validation_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + + + def save_prompt_embedding(model,path): + prompt_embedding = model.state_dict()['encoder.prompt_embeddings.weight'] + save_prompt_info = {'encoder.prompt_embeddings.weight':copy.deepcopy(prompt_embedding),'task2id':task2id,'format2id':format2id} + prompt_path = os.path.join(path,'prompt_embedding_info') + torch.save(save_prompt_info,prompt_path) + logger.info(f'Saving prompt embedding information to {prompt_path}') + + + + def preprocess_function(examples): + preprocess_fn = dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, question_column, context_column, answer_column) + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=data_args.max_target_length, padding=padding, truncation=True) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def save_load_diverse_sample(model,trainset): + with torch.no_grad(): + device = 'cuda:0' + search_upbound = len(trainset)//4 + query_idxs = [None]*30 + keys = model.encoder.domain_keys + for idx,item in enumerate(trainset.select(range(search_upbound))): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + result = torch.matmul(query,keys.t()) + result = torch.topk(result,5).indices[0].cpu().numpy().tolist() + key_sel = None + for key_idx in result: + if query_idxs[key_idx] is None or len(query_idxs[key_idx])<3: + key_sel = key_idx + break + if key_sel is not None: + if query_idxs[key_sel] is None: + query_idxs[key_sel] = [idx] + else: + query_idxs[key_sel].append(idx) + total_idxs = [] + for item in query_idxs: + try: + total_idxs.extend(item[:3]) + except: + total_idxs.extend(random.sample(list(range(search_upbound,len(trainset))),3)) + total_idxs = list(set(total_idxs)) + total_idxs = random.sample(total_idxs,50) + sub_set = trainset.select(total_idxs) + + features = [] + for idx,item in enumerate(sub_set): + query = model.encoder.get_query_vector(input_ids=torch.tensor([item['input_ids']]).long().to(device), + attention_mask=torch.tensor([item['attention_mask']]).long().to(device), + return_dict=True) + features.append(query.detach().cpu().numpy()) + + + + return sub_set,features + + + + + + def dataset_name_to_func(dataset_name): + mapping = { + 'squad': preprocess_sqaud_batch, + 'squad_v2': 
preprocess_sqaud_batch, + 'boolq': preprocess_boolq_batch, + 'narrativeqa': preprocess_narrativeqa_batch, + 'race': preprocess_race_batch, + 'newsqa': preprocess_newsqa_batch, + 'quoref': preprocess_sqaud_batch, + 'ropes': preprocess_ropes_batch, + 'drop': preprocess_drop_batch, + 'nqopen': preprocess_sqaud_abstractive_batch, + # 'multirc': preprocess_boolq_batch, + 'boolq_np': preprocess_boolq_batch, + 'openbookqa': preprocess_openbookqa_batch, + 'mctest': preprocess_race_batch, + 'social_iqa': preprocess_social_iqa_batch, + 'dream': preprocess_dream_batch, + } + return mapping[dataset_name] + + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dataset_name_to_metric = { + 'squad': 'metric/squad_v1_local/squad_v1_local.py', + 'squad_v2': 'metric/squad_v2_local/squad_v2_local.py', + 'newsqa': 'metric/squad_v2_local/squad_v2_local.py', + 'boolq': 'metric/accuracy.py', + 'narrativeqa': 'metric/rouge_local/rouge_metric.py', + 'race': 'metric/accuracy.py', + 'quoref': 'metric/squad_v1_local/squad_v1_local.py', + 'ropes': 'metric/squad_v1_local/squad_v1_local.py', + 'drop': 'metric/squad_v1_local/squad_v1_local.py', + 'nqopen': 'metric/squad_v1_local/squad_v1_local.py', + 'boolq_np': 'metric/accuracy.py', + 'openbookqa': 'metric/accuracy.py', + 'mctest': 'metric/accuracy.py', + 'social_iqa': 'metric/accuracy.py', + 'dream': 'metric/accuracy.py', + } + + + + + + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " + "`--source_prefix 'translate English to German: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+    set_seed(training_args.seed)
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
+    # tokenizer.add_tokens(['[TASK]', '[ABSTRACTIVE]','[QUESTION]','[CONTEXT]','[BOOL]','[EXTRACTIVE]','[MultiChoice]',
+    #                       '[OPTIONS]'])
+    tokens_to_add = ['[ABSTRACTIVE]', '[BOOL]', '[EXTRACTIVE]', '[MultiChoice]']
+    special_tokens_dict = {'additional_special_tokens': ['[TASK]', '[QUESTION]', '[CONTEXT]',
+                                                         '[OPTIONS]']}
+
+    tokenizer.add_tokens(tokens_to_add)
+    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+    added_tokens = tokenizer.get_added_vocab()
+    logger.info('Added tokens: {}'.format(added_tokens))
+
+    model = PromptT5.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+        # task_num = data_args.max_task_num,
+        # prompt_num = data_args.prompt_number,
+        # format_num = data_args.qa_task_type_num,
+        # add_task_prompt = False
+    )
+
+    model.resize_token_embeddings(len(tokenizer))
+
+
+    # Reload format-specific task prompts for newly involved tasks (disabled below).
+
+    data_args.reload_from_trained_prompt = False  # hard-coded off for these runs
+    data_args.load_from_format_task_id = False    # hard-coded off for these runs
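+    # When enabled, load_from_format_task_id seeds the new task's prompt rows from its format-level prompt,
+    # and reload_from_trained_prompt restores task+format prompt rows from a saved prompt_embedding_info checkpoint.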
+ + if data_args.load_from_format_task_id and (data_args.dataset_name not in seed_datasets) and not data_args.reload_from_trained_prompt: + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + format_task_id = task_start_id + task2id[dataset2format[data_args.dataset_name]] * data_args.prompt_number + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = model.state_dict()['encoder.prompt_embeddings.weight'][format_task_id:format_task_id+data_args.prompt_number,:] + logger.info(f'Successfully initialize format {dataset2format[data_args.dataset_name]} task prompt for new task {data_args.dataset_name}, task id {task_id}') + # print(dataset2format[data_args.dataset_name]) + # print(data_args.dataset_name) + elif data_args.reload_from_trained_prompt: + assert data_args.trained_prompt_path,'Must specify the path of stored prompt' + prompt_info = torch.load(data_args.trained_prompt_path) + assert prompt_info['task2id'][data_args.dataset_name]==task2id[data_args.dataset_name],f'the task id in trained prompt task id is not matched to the current task id for {data_args.dataset_name}' + assert prompt_info['format2id'].keys()==format2id.keys(),'the format dont match' + task_start_id = data_args.prompt_number * len(format2dataset.keys()) + + task_id = task_start_id + task2id[data_args.dataset_name] * data_args.prompt_number + logger.info('task id range {} {}'.format(task_id,task_id+data_args.prompt_number)) + # assert torch.sum(model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] - prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:])==0 + model.state_dict()['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] = prompt_info['encoder.prompt_embeddings.weight'][task_id:task_id+data_args.prompt_number,:] + format_id = format2id[dataset2format[data_args.dataset_name]] + model.state_dict()['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] = prompt_info['encoder.prompt_embeddings.weight'][format_id*data_args.prompt_number:(format_id+1)*data_args.prompt_number, :] + logger.info( + f'Successfully restore task+format prompt for the task {data_args.dataset_name} from {data_args.trained_prompt_path}') + + # Set decoder_start_token_id + if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): + if isinstance(tokenizer, MBartTokenizer): + model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] + else: + model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + + if training_args.local_rank == -1 or training_args.no_cuda: + device = torch.device("cuda") + n_gpu = torch.cuda.device_count() + + + + + + # Temporarily set max_target_length for training. 
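+    # Padding to max_length happens only with --pad_to_max_length; otherwise DataCollatorForSeq2Seq
+    # pads each batch dynamically (to a multiple of 8 when fp16 is enabled).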
+ max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + + question_column = data_args.question_column + context_column = data_args.context_column + answer_column = data_args.answer_column + # import random + if data_args.max_source_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_source_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + data_args.max_source_length = min(data_args.max_source_length, tokenizer.model_max_length) + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + if data_args.pad_to_max_length: + data_collator = default_data_collator + else: + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + #start + train_dataloaders = {} + eval_dataloaders = {} + replay_dataloaders = {} + all_replay = None + + for ds_name in ["squad","wikisql","sst","srl","woz.en"]: + + if True: + cur_dataset = ds_name + train_dataloaders[ds_name] = load_from_disk("./oursdecanotask/{}-train.hf".format(cur_dataset)).select(range(200)) + eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(cur_dataset)).select(range(200)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(cur_dataset)).select(range(200))) + + for ds_name in ["cnn_dailymail","multinli.in.out","zre"]: + eval_dataloaders[ds_name] = (load_from_disk("./oursdecanotask/{}-eval.hf".format(ds_name)),load_from_disk("./oursdecanotask/{}-evalex.hf".format(ds_name))) + + pre_tasks = [] + pre_general = [] + pre_test = [] + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams + + task_sequence = ["squad","wikisql","sst","srl","woz.en"] + # task_sequence = ["woz.en","srl","sst","wikisql","squad"] + need_to_do_dss = [] + fileout = open("diana_log.txt",'w') + all_replay = None + all_features = [] + all_ids = [] + cluster_num=0 + for cur_dataset in task_sequence: + cluster_num+=5 + pre_tasks.append(cur_dataset) + if cur_dataset==task_sequence[-1]: + pre_tasks.extend(["cnn_dailymail","multinli.in.out","zre"]) + data_args.dataset_name = cur_dataset + logger.info("current_dataset:"+cur_dataset) + + + training_args.do_train = True + training_args.to_eval = False + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + if all_replay is not None: + fused = datasets.concatenate_datasets([all_replay,train_dataloaders[cur_dataset]]) + else: + fused = train_dataloaders[cur_dataset] + training_args.num_train_epochs = 5 + model.encoder.reset_train_count() + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=fused, + eval_dataset=None, + eval_examples=None, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + 
tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + train_result = trainer.train() + + if training_args.local_rank<=0: + save_set,features = save_load_diverse_sample(model,train_dataloaders[cur_dataset]) + + if all_replay is None: + all_replay = save_set + else: + all_replay = datasets.concatenate_datasets([all_replay,save_set]) + + if all_features==[]: + all_features=features + else: + all_features.extend(features) + np.save("./all_features.npy",np.array(all_features)) + + all_replay.save_to_disk("all_replay@{}.hf".format(cur_dataset)) + + + + if training_args.local_rank!=-1: + torch.distributed.barrier() + all_replay = load_from_disk("all_replay@{}.hf".format(cur_dataset)) + all_ids.extend([task2id[cur_dataset]]*50) + all_features=np.load("./all_features.npy").tolist() + + model.encoder.add_negs(all_ids,all_features) + + + + + + + + + for pre_dataset in pre_tasks: + + + data_args.dataset_name = pre_dataset + + metric = load_metric("metric/squad_v1_local/squad_v1_local.py") + eval_dataset,eval_examples = eval_dataloaders[pre_dataset] + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=eval_dataset, + eval_examples=eval_examples, + answer_column_name=answer_column, + dataset_name=data_args.dataset_name, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + + + + torch.cuda.empty_cache() + logger.info("*** Evaluate:{} ***".format(data_args.dataset_name)) + + max_length, num_beams, ignore_keys_for_eval = None, None, None + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, ignore_keys=ignore_keys_for_eval,metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + if training_args.local_rank<=0: + try: + print("after_train_",cur_dataset,"_test_",pre_dataset,file=fileout) + print(metrics,file=fileout) + except: + pass + + + + + + + languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] + if len(languages) > 0: + kwargs["language"] = languages + return None + + +# - + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/diana/downstreamdeca/dianawotask.sh b/diana/downstreamdeca/dianawotask.sh new file mode 100644 index 0000000..9ce8915 --- /dev/null +++ b/diana/downstreamdeca/dianawotask.sh @@ -0,0 +1,53 @@ +function fulldata_hfdata() { +SAVING_PATH=$1 +PRETRAIN_MODEL_PATH=$2 +TRAIN_BATCH_SIZE=$3 +EVAL_BATCH_SIZE=$4 + + + + + +mkdir -p ${SAVING_PATH} + +python -m torch.distributed.launch --nproc_per_node=4 dianawotask.py \ + --model_name_or_path ${PRETRAIN_MODEL_PATH} \ + --output_dir ${SAVING_PATH} \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size ${TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${EVAL_BATCH_SIZE} \ + --overwrite_output_dir \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 5 \ + --warmup_ratio 0.1 \ + --logging_steps 5 \ + --learning_rate 1e-4 \ + --predict_with_generate \ + --num_beams 4 \ + --save_strategy no \ + --evaluation_strategy no \ + --weight_decay 1e-2 \ + --max_source_length 512 \ + --label_smoothing_factor 0.1 \ + --do_lowercase True \ + 
--load_best_model_at_end True \ + --greater_is_better True \ + --save_total_limit 10 \ + --ddp_find_unused_parameters True 2>&1 | tee ${SAVING_PATH}/log +} + +SAVING_PATH=$1 +TRAIN_BATCH_SIZE=$2 +EVAL_BATCH_SIZE=$3 + + + + + +MODE=try + + +SAVING_PATH=${SAVING_PATH}/lifelong/${MODE} +fulldata_hfdata ${SAVING_PATH} t5-base ${TRAIN_BATCH_SIZE} ${EVAL_BATCH_SIZE} diff --git a/diana/downstreamdeca/extract.py b/diana/downstreamdeca/extract.py new file mode 100644 index 0000000..a4a558b --- /dev/null +++ b/diana/downstreamdeca/extract.py @@ -0,0 +1,169 @@ +import itertools +import json +import os +import csv +import errno +import random +from random import shuffle +from typing import List +import codecs +import nltk +import glob +import xml.etree.ElementTree as ET +from datasets import load_dataset + +#nltk.download("stopwords") +from nltk.corpus import stopwords + + + + + +def preprocess_function(examples): + preprocess_fn = preprocess_all#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "input","output") + model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=128, padding=padding, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + + + return model_inputs + +def preprocess_function_test(examples): + preprocess_fn = preprocess_all#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "input","output") + model_inputs = tokenizer(inputs, max_length=1024, padding=padding, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=128, padding=padding, truncation=True) + model_inputs["example_id"] = [] + model_inputs["id"] = [] + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
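+        # No sliding-window chunking is applied here, so every example yields exactly
+        # one feature and both "example_id" and "id" are simply the row index.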
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(i) + model_inputs["id"].append(i) + + model_inputs["labels"] = labels["input_ids"] + + + return model_inputs + +def add_id(example,index): + example.update({'id':index}) + return example + +def prep(raw_ds,fname): + ds = [] + dss = map(lambda x: x["paragraphs"], raw_ds["data"]) + for dd in dss: + ds.extend(dd) + print(len(ds)) + print(len(raw_ds["data"])) + + examples = {"context":[],"question":[],"answer":[]} + fout = open("{}-train.jsonl".format(fname),'w') + for d in ds: + + context = d["context"] + #TOKENIZER.encode(d["context"]) + for qa in d["qas"]: + question = qa["question"]#TOKENIZER.encode(qa["question"]) + + raw_answers = qa["answers"] + if len(raw_answers) == 0: + assert qa["is_impossible"] + raw_answers.append({"text": ""}) + + answers = [] + + for i, raw_answer in enumerate(raw_answers): + answers.append(raw_answer["text"]) + jsonline = json.dumps({"question":question,"context":context,"answer":answers}) + print(jsonline,file=fout) + fout.close() + + + + + + +def prep_test(raw_ds,use_answers,fname): + ds = [] + dss = map(lambda x: x["paragraphs"], raw_ds["data"]) + for dd in dss: + ds.extend(dd) + print(len(ds)) + print(len(raw_ds["data"])) + fout = open("{}-eval.jsonl".format(fname),'w') + idx = 0 + f_answers = [] + use_answers = None + all_json_lines = [] + for d in ds: + + context = d["context"] + #TOKENIZER.encode(d["context"]) + for qa in d["qas"]: + question = qa["question"]#TOKENIZER.encode(qa["question"]) + raw_answers = qa["answers"] + f_answers.extend([_["text"] for _ in qa["answers"]]) + if True: + if len(raw_answers) == 0: + assert qa["is_impossible"] + raw_answers.append({"text": ""}) + + answers = [] + + for i, raw_answer in enumerate(raw_answers): + answers.append(raw_answer["text"]) + all_json_lines.append({"question":question,"context":context,"answer":answers,"preid":qa["id"]}) + if fname in ["wikisql","woz.en","multinli.in.out"]: + all_json_lines.sort(key=lambda x: x["preid"]) + + for item in all_json_lines: + jsonline = json.dumps(item) + print(jsonline,file=fout) + + fout.close() + + +from collections import Counter +import json + + +def main(): + for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]: + print("current_dataset:",item) + print("Loading......") + + testfile = open(item+"_to_squad-test-v2.0.json",'r') + if not item in ["cnn_dailymail","multinli.in.out","zre"]: + trainfile = open(item+"_to_squad-train-v2.0.json",'r') + ftrain = json.load(trainfile) + ftest = json.load(testfile) + faddition = None + if item in ["woz.en","wikisql"]: + addition_answers = open(item+"_answers.json",'r') + faddition = json.load(addition_answers) + if item=="woz.en": + faddition = [[_[1]] for _ in faddition] + else: + faddition = [[_["answer"]] for _ in faddition] + else: + faddition = None + print("Loaded.") + if not item in ["cnn_dailymail","multinli.in.out","zre"]: + prep(ftrain,item) + + prep_test(ftest,faddition,item) + + + + +main() \ No newline at end of file diff --git a/diana/downstreamdeca/l2ptrainer.py b/diana/downstreamdeca/l2ptrainer.py new file mode 100755 index 0000000..da6c09d --- /dev/null +++ b/diana/downstreamdeca/l2ptrainer.py @@ -0,0 +1,557 @@ +from transformers import Seq2SeqTrainer, is_torch_tpu_available, EvalPrediction +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import nltk +import datasets +import re +import os +import numpy as np +import torch +import random +from pathlib 
import Path +import nltk +# nltk.download('punkt') +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + HPSearchBackend, + HubStrategy, + IntervalStrategy, + PredictionOutput, + ShardedDDPOption, + TrainerMemoryTracker, + TrainOutput, + default_compute_objective, + default_hp_space, + denumpify_detensorize, + get_last_checkpoint, + number_of_arguments, + set_seed, + speed_metrics, +) +import warnings +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + DistributedTensorGatherer, + IterableDatasetShard, + LabelSmoother, + LengthGroupedSampler, + SequentialDistributedSampler, + ShardSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.file_utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + get_full_repo_name, + is_apex_available, + is_datasets_available, + is_in_notebook, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_tpu_available, +) +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" +from transformers.trainer_utils import PredictionOutput,EvalLoopOutput +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + +def fix_buggy_characters(str): + return re.sub("[{}^\\\\`\u2047<]", " ", str) +def replace_punctuation(str): + return str.replace("\"", "").replace("'", "") +def score_string_similarity(str1, str2): + if str1 == str2: + return 3.0 # Better than perfect token match + str1 = fix_buggy_characters(replace_punctuation(str1)) + str2 = fix_buggy_characters(replace_punctuation(str2)) + if str1 == str2: + return 2.0 + if " " in str1 or " " in str2: + str1_split = str1.split(" ") + str2_split = str2.split(" ") + overlap = list(set(str1_split) & set(str2_split)) + return len(overlap) / max(len(str1_split), len(str2_split)) + else: + if str1 == str2: + return 1.0 + else: + return 0.0 + +class QuestionAnsweringTrainer(Seq2SeqTrainer): + def __init__(self, *args, tokenizer,eval_examples=None,answer_column_name='answers',dataset_name='squad', **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.answer_column_name = answer_column_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + if "cnn" in dataset_name: + self.post_process_function = self._post_process_narrative_qa + else: + self.post_process_function = self._post_process_squad + + + def _post_process_squad( + self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval" + ,version_2_with_negative=False): + # Decode the predicted tokens. + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} + predictions = {} + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # This is the index of the feature associated to the current example. 
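+            # Each example is expected to map to exactly one feature; a missing key
+            # means the eval examples and the tokenized features are out of sync,
+            # which the except branch below surfaces before failing.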
+ try: + feature_index = feature_per_example[example_index] + except: + print(feature_per_example.keys()) + print(example_index) + assert False + predictions[example["id"]] = decoded_preds[feature_index] + + # Format the result to the format the metric expects. + if version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + def _post_process_squad_v2( + self, examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"): + # Decode the predicted tokens. + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)} + predictions = {} + # Let's loop over all the examples! + for example_index, example in enumerate(examples): + # This is the index of the feature associated to the current example. + feature_index = feature_per_example[example_index] + predictions[example["id"]] = decoded_preds[feature_index] + + # Format the result to the format the metric expects. + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + + references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + @classmethod + def postprocess_text(cls,preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + return preds, labels + + def _post_process_boolq(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + outputs = [] + # for pred in preds: + # if 'True' == pred: + # outputs.append(1) + # elif 'False' == pred: + # outputs.append(0) + # else: + # outputs.append(0) + # references = [1 if ex['answer'] else 0 for ex in examples] + + references = [] + for pred, ref in zip(preds, examples): + ans = ref['answer'] + references.append(1 if ans else 0) + if pred.lower() in ['true', 'false']: + outputs.append(1 if pred.lower() == 'true' else 0) + else: + # generate a wrong prediction if pred is not true/false + outputs.append(0 if ans else 1) + + assert(len(references)==len(outputs)) + formatted_predictions = [] + return EvalPrediction(predictions=outputs, label_ids=references) + + def _post_process_narrative_qa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = 
self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + references = [exp['answers']['text'][0] for exp in examples] + formatted_predictions = [] + preds, references = self.postprocess_text(preds,references) + assert(len(preds)==len(references)) + return EvalPrediction(predictions=preds, label_ids=references) + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + inputs["train"]=True + else: + labels = None + outputs,sim_loss = model(**inputs) + + # print(model.encoder.) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + loss = self.label_smoother(outputs, labels) + else: + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + + loss += sim_loss*0.1 + return (loss, outputs) if return_outputs else loss + + + def _post_process_drop(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + # preds = [" ".join(pred) for pred in decoded_preds] + preds = decoded_preds + references = [exp['answers'][0]['text'] for exp in examples] + formatted_predictions = [] + preds, references = self.postprocess_text(preds, references) + assert (len(preds) == len(references)) + return EvalPrediction(predictions=preds, label_ids=references) + + def _post_process_race(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answer'], 'options':ex['options']} for ex in examples] + assert(len(references)==len(decoded_preds)) + gold_ids , pred_ids = [],[] + for prediction, reference in zip(decoded_preds, references): + # reference = json.loads(reference) + gold = int(ord(reference['answer'].strip()) - ord('A')) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + # print('options',options,type(options)) + # print('prediction',prediction) + # print('answer',gold) + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + # selected_ans = chr(ord('A') + max_idx) + # print(len(references),len(decoded_preds)) + + return EvalPrediction(predictions=pred_ids,label_ids = gold_ids) + + def _post_process_openbookqa(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answerKey'], 'options': ex['choices']['text']} for ex in examples] + assert len(references) 
== len(decoded_preds) + gold_ids, pred_ids = [], [] + for prediction, reference in zip(decoded_preds, references): + # reference = json.loads(reference) + gold = int(ord(reference['answer'].strip()) - ord('A')) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + return EvalPrediction(predictions=pred_ids,label_ids = gold_ids) + + def _post_process_dream(self,examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, tokenizer, stage="eval"): + preds = outputs.predictions + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + references = [{"answer": ex['answer'], 'options': ex['choice']} for ex in examples] + assert(len(references)==len(decoded_preds)) + gold_ids, pred_ids = [],[] + for prediction, reference in zip(decoded_preds, references): + gold = reference['options'].index(reference['answer']) + options = reference['options'] + prediction = prediction.replace("\n", "").strip() + options = [opt.strip() for opt in options if len(opt) > 0] + scores = [score_string_similarity(opt, prediction) for opt in options] + max_idx = np.argmax(scores) + gold_ids.append(gold) + pred_ids.append(max_idx) + return EvalPrediction(predictions=pred_ids, label_ids = gold_ids) + + def evaluate(self, eval_dataset=None, eval_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "eval", + max_length: Optional[int] = None,num_beams: Optional[int] = None): + gen_kwargs = {} + gen_kwargs["max_length"] = ( + max_length + ) + gen_kwargs["num_beams"] = ( + num_beams + ) + self._gen_kwargs = gen_kwargs + self._memory_tracker.start() + self._max_length = max_length if max_length is not None else self.args.generation_max_length + self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. 
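+        # The prediction loop below only generates sequences; the dataset-specific
+        # metrics (EM/F1, logical form, dialogue accuracy, ROUGE, corpus F1) are
+        # computed afterwards from the post-processed predictions.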
+ compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, output,tokenizer) + metrics = self.compute_metrics(eval_preds, + tri=self.dataset_name=="multinli.in.out", + dialogue=self.dataset_name=='woz.en', + rouge=self.dataset_name=='cnn_dailymail', + logical_form=self.dataset_name=='wikisql', + corpus_f1=self.dataset_name=='zre') + if self.dataset_name=='narrativeqa': + metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()} + metrics = {k: round(v, 4) for k, v in metrics.items()} + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + print(metrics) + else: + metrics = {} + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + self._memory_tracker.stop_and_update_metrics(metrics) + return metrics + + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + save_only_best = True + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + run_id = tune.get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + self.store_flos() + + output_dir = os.path.join(run_dir, checkpoint_folder) + if not save_only_best: + self.save_model(output_dir) + if self.deepspeed: + # under zero3 model file itself doesn't get saved since it's bogus! 
Unless deepspeed + # config `stage3_gather_fp16_weights_on_model_save` is True + self.deepspeed.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + if smp.dp_rank() == 0: + # Consolidate the state dict on all processed of dp_rank 0 + opt_state_dict = self.optimizer.state_dict() + # Save it and the scheduler on the main process + if self.args.should_save and not save_only_best: + torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.deepspeed and not save_only_best: + # deepspeed.save_checkpoint above saves model/optim/sched + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + self.do_grad_scaling=True + # if self.do_grad_scaling: + # torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + best_dir = os.path.join(run_dir,'best-checkpoint') + # if not os.path.exists(best_dir): + # os.mkdir(best_dir,exist_ok=True) + if self.args.should_save: + self.save_model(best_dir) + self.state.save_to_json(os.path.join(best_dir, TRAINER_STATE_NAME)) + + + # Save the Trainer state + if self.args.should_save and not save_only_best: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.local_rank == -1: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. 
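+        # Save the RNG states so training can be resumed with the same random state:
+        # a single shared rng_state.pth in non-distributed runs, or one
+        # rng_state_{rank}.pth per process under distributed training.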
+ os.makedirs(output_dir, exist_ok=True) + local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank + if local_rank == -1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save and not save_only_best: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + + + + # def predict(self, predict_dataset=None, predict_examples=None, tokenizer=None,ignore_keys=None, metric_key_prefix: str = "predict", + # max_length: Optional[int] = None,num_beams: Optional[int] = None): + # self._memory_tracker.start() + # self._max_length = max_length if max_length is not None else self.args.generation_max_length + # self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams + # predict_dataset = self.predict_dataset if predict_dataset is None else predict_dataset + # predict_dataloader = self.get_eval_dataloader(predict_dataset) + # predict_examples = predict_examples + # + # # Temporarily disable metric computation, we will do it in the loop here. + # compute_metrics = self.compute_metrics + # self.compute_metrics = None + # eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + # try: + # output = eval_loop( + # predict_dataloader, + # description="Predict", + # # No point gathering the predictions if there are no metrics, otherwise we defer to + # # self.args.prediction_loss_only + # prediction_loss_only=True if compute_metrics is None else None, + # ignore_keys=ignore_keys, + # ) + # finally: + # self.compute_metrics = compute_metrics + # + # if self.post_process_function is not None and self.compute_metrics is not None: + # predict_preds = self.post_process_function(predict_examples, predict_dataset, output,tokenizer) + # metrics = self.compute_metrics(predict_preds) + # if self.dataset_name=='narrativeqa': + # metrics = {key: value.mid.fmeasure * 100 for key, value in metrics.items()} + # metrics = {k: round(v, 4) for k, v in metrics.items()} + # # Prefix all keys with metric_key_prefix + '_' + # for key in list(metrics.keys()): + # if not key.startswith(f"{metric_key_prefix}_"): + # metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + # + # self.log(metrics) + # else: + # metrics = {} + # + # self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + # self._memory_tracker.stop_and_update_metrics(metrics) + # return metrics diff --git a/diana/downstreamdeca/metric/accuracy.py b/diana/downstreamdeca/metric/accuracy.py new file mode 100644 index 0000000..b9b0e33 --- /dev/null +++ b/diana/downstreamdeca/metric/accuracy.py @@ -0,0 +1,105 @@ +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Accuracy metric.""" + +from sklearn.metrics import accuracy_score + +import datasets + + +_DESCRIPTION = """ +Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with: +Accuracy = (TP + TN) / (TP + TN + FP + FN) + Where: +TP: True positive +TN: True negative +FP: False positive +FN: False negative +""" + + +_KWARGS_DESCRIPTION = """ +Args: + predictions (`list` of `int`): Predicted labels. + references (`list` of `int`): Ground truth labels. + normalize (`boolean`): If set to False, returns the number of correctly classified samples. Otherwise, returns the fraction of correctly classified samples. Defaults to True. + sample_weight (`list` of `float`): Sample weights Defaults to None. + +Returns: + accuracy (`float` or `int`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher score means higher accuracy. + +Examples: + + Example 1-A simple example + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0]) + >>> print(results) + {'accuracy': 0.5} + + Example 2-The same as Example 1, except with `normalize` set to `False`. + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], normalize=False) + >>> print(results) + {'accuracy': 3.0} + + Example 3-The same as Example 1, except with `sample_weight` set. + >>> accuracy_metric = datasets.load_metric("accuracy") + >>> results = accuracy_metric.compute(references=[0, 1, 2, 0, 1, 2], predictions=[0, 1, 1, 2, 1, 0], sample_weight=[0.5, 2, 0.7, 0.5, 9, 0.4]) + >>> print(results) + {'accuracy': 0.8778625954198473} +""" + + +_CITATION = """ +@article{scikit-learn, + title={Scikit-learn: Machine Learning in {P}ython}, + author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. + and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. + and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and + Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, + journal={Journal of Machine Learning Research}, + volume={12}, + pages={2825--2830}, + year={2011} +} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Accuracy(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("int32")), + "references": datasets.Sequence(datasets.Value("int32")), + } + if self.config_name == "multilabel" + else { + "predictions": datasets.Value("int32"), + "references": datasets.Value("int32"), + } + ), + reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"], + ) + + def _compute(self, predictions, references, normalize=True, sample_weight=None): + return { + "accuracy": float( + accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight) + ) + } diff --git a/diana/downstreamdeca/metric/bleu.py b/diana/downstreamdeca/metric/bleu.py new file mode 100644 index 0000000..545cb1e --- /dev/null +++ b/diana/downstreamdeca/metric/bleu.py @@ -0,0 +1,126 @@ +# Copyright 2020 The HuggingFace Datasets Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BLEU metric. """ + +import datasets + +from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py + + +_CITATION = """\ +@INPROCEEDINGS{Papineni02bleu:a, + author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, + title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, + booktitle = {}, + year = {2002}, + pages = {311--318} +} +@inproceedings{lin-och-2004-orange, + title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", + author = "Lin, Chin-Yew and + Och, Franz Josef", + booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", + month = "aug 23{--}aug 27", + year = "2004", + address = "Geneva, Switzerland", + publisher = "COLING", + url = "https://www.aclweb.org/anthology/C04-1072", + pages = "501--507", +} +""" + +_DESCRIPTION = """\ +BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. +Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, +the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and +remains one of the most popular automated and inexpensive metrics. + +Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. +Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness +are not taken into account[citation needed]. + +BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1 +representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the +reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional +reference translations will increase the BLEU score. +""" + +_KWARGS_DESCRIPTION = """ +Computes BLEU score of translated segments against one or more references. +Args: + predictions: list of translations to score. + Each translation should be tokenized into a list of tokens. + references: list of lists of references for each translation. + Each reference should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. 
+Returns: + 'bleu': bleu score, + 'precisions': geometric mean of n-gram precisions, + 'brevity_penalty': brevity penalty, + 'length_ratio': ratio of lengths, + 'translation_length': translation_length, + 'reference_length': reference_length +Examples: + + >>> predictions = [ + ... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample + ... ["foo", "bar", "foobar"] # tokenized prediction of the second sample + ... ] + >>> references = [ + ... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references) + ... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference) + ... ] + >>> bleu = datasets.load_metric("bleu") + >>> results = bleu.compute(predictions=predictions, references=references) + >>> print(results["bleu"]) + 1.0 +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Bleu(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), + "references": datasets.Sequence( + datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references" + ), + } + ), + codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], + reference_urls=[ + "https://en.wikipedia.org/wiki/BLEU", + "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", + ], + ) + + def _compute(self, predictions, references, max_order=4, smooth=False): + score = compute_bleu( + reference_corpus=references, translation_corpus=predictions, max_order=max_order, smooth=smooth + ) + (bleu, precisions, bp, ratio, translation_length, reference_length) = score + return { + "bleu": bleu, + "precisions": precisions, + "brevity_penalty": bp, + "length_ratio": ratio, + "translation_length": translation_length, + "reference_length": reference_length, + } diff --git a/diana/downstreamdeca/metric/bleu_local/bleu.py b/diana/downstreamdeca/metric/bleu_local/bleu.py new file mode 100644 index 0000000..03eda55 --- /dev/null +++ b/diana/downstreamdeca/metric/bleu_local/bleu.py @@ -0,0 +1,219 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BLEU metric. 
""" + +import datasets + +#from .nmt_bleu import compute_bleu # From: https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py + + +_CITATION = """\ +@INPROCEEDINGS{Papineni02bleu:a, + author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, + title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, + booktitle = {}, + year = {2002}, + pages = {311--318} +} +@inproceedings{lin-och-2004-orange, + title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", + author = "Lin, Chin-Yew and + Och, Franz Josef", + booktitle = "{COLING} 2004: Proceedings of the 20th International Conference on Computational Linguistics", + month = "aug 23{--}aug 27", + year = "2004", + address = "Geneva, Switzerland", + publisher = "COLING", + url = "https://www.aclweb.org/anthology/C04-1072", + pages = "501--507", +} +""" + +_DESCRIPTION = """\ +BLEU (bilingual evaluation understudy) is an algorithm for evaluating the quality of text which has been machine-translated from one natural language to another. +Quality is considered to be the correspondence between a machine's output and that of a human: "the closer a machine translation is to a professional human translation, +the better it is" – this is the central idea behind BLEU. BLEU was one of the first metrics to claim a high correlation with human judgements of quality, and +remains one of the most popular automated and inexpensive metrics. + +Scores are calculated for individual translated segments—generally sentences—by comparing them with a set of good quality reference translations. +Those scores are then averaged over the whole corpus to reach an estimate of the translation's overall quality. Intelligibility or grammatical correctness +are not taken into account[citation needed]. + +BLEU's output is always a number between 0 and 1. This value indicates how similar the candidate text is to the reference texts, with values closer to 1 +representing more similar texts. Few human translations will attain a score of 1, since this would indicate that the candidate is identical to one of the +reference translations. For this reason, it is not necessary to attain a score of 1. Because there are more opportunities to match, adding additional +reference translations will increase the BLEU score. +""" + +_KWARGS_DESCRIPTION = """ +Computes BLEU score of translated segments against one or more references. +Args: + predictions: list of translations to score. + Each translation should be tokenized into a list of tokens. + references: list of lists of references for each translation. + Each reference should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. +Returns: + 'bleu': bleu score, + 'precisions': geometric mean of n-gram precisions, + 'brevity_penalty': brevity penalty, + 'length_ratio': ratio of lengths, + 'translation_length': translation_length, + 'reference_length': reference_length +Examples: + + >>> predictions = [ + ... ["hello", "there", "general", "kenobi"], # tokenized prediction of the first sample + ... ["foo", "bar", "foobar"] # tokenized prediction of the second sample + ... ] + >>> references = [ + ... [["hello", "there", "general", "kenobi"], ["hello", "there", "!"]], # tokenized references for the first sample (2 references) + ... [["foo", "bar", "foobar"]] # tokenized references for the second sample (1 reference) + ... 
] + >>> bleu = datasets.load_metric("bleu") + >>> results = bleu.compute(predictions=predictions, references=references) + >>> print(results["bleu"]) + 1.0 +""" + +import collections +import math + + +def _get_ngrams(segment, max_order): + """Extracts all n-grams upto a given maximum order from an input segment. + Args: + segment: text segment from which n-grams will be extracted. + max_order: maximum length in tokens of the n-grams returned by this + methods. + Returns: + The Counter containing all n-grams upto max_order in segment + with a count of how many times each n-gram occurred. + """ + ngram_counts = collections.Counter() + for order in range(1, max_order + 1): + for i in range(0, len(segment) - order + 1): + ngram = tuple(segment[i:i+order]) + ngram_counts[ngram] += 1 + return ngram_counts + + +def compute_bleu(reference_corpus, translation_corpus, max_order=1, + smooth=False): + """Computes BLEU score of translated segments against one or more references. + Args: + reference_corpus: list of lists of references for each translation. Each + reference should be tokenized into a list of tokens. + translation_corpus: list of translations to score. Each translation + should be tokenized into a list of tokens. + max_order: Maximum n-gram order to use when computing BLEU score. + smooth: Whether or not to apply Lin et al. 2004 smoothing. + Returns: + 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram + precisions and brevity penalty. + """ + matches_by_order = [0] * max_order + possible_matches_by_order = [0] * max_order + reference_length = 0 + translation_length = 0 + for (references, translation) in zip(reference_corpus, + translation_corpus): + # translation = references[0] + translation = [item.lower() for item in translation] + + reference_length += min(len(r) for r in references) + translation_length += len(translation) + + merged_ref_ngram_counts = collections.Counter() + + for reference in references: + merged_ref_ngram_counts |= _get_ngrams(reference, max_order) + # print(merged_ref_ngram_counts) + translation_ngram_counts = _get_ngrams(translation, max_order) + overlap = translation_ngram_counts & merged_ref_ngram_counts + for ngram in overlap: + matches_by_order[len(ngram)-1] += overlap[ngram] + for order in range(1, max_order+1): + possible_matches = len(translation) - order + 1 + if possible_matches > 0: + possible_matches_by_order[order-1] += possible_matches + + + precisions = [0] * max_order + for i in range(0, max_order): + if smooth: + precisions[i] = ((matches_by_order[i] + 1.) / + (possible_matches_by_order[i] + 1.)) + else: + if possible_matches_by_order[i] > 0: + precisions[i] = (float(matches_by_order[i]) / + possible_matches_by_order[i]) + else: + precisions[i] = 0.0 + + + if min(precisions) > 0: + p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) + geo_mean = math.exp(p_log_sum) + else: + geo_mean = 0 + + ratio = float(translation_length) / reference_length + + if ratio > 1.0: + bp = 1. + else: + bp = math.exp(1 - 1. 
/ ratio) + + bleu = geo_mean * bp + + + return (bleu, precisions, bp, ratio, translation_length, reference_length) + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Bleu(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), + "references": datasets.Sequence( + datasets.Sequence(datasets.Value("string", id="token"), id="sequence"), id="references" + ), + } + ), + codebase_urls=["https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py"], + reference_urls=[ + "https://en.wikipedia.org/wiki/BLEU", + "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", + ], + ) + + def _compute(self, predictions, references, max_order=4, smooth=False): + score = compute_bleu( + reference_corpus=references, translation_corpus=predictions, max_order=1, smooth=smooth + ) + (bleu, precisions, bp, ratio, translation_length, reference_length) = score + return { + "bleu": bleu, + "precisions": precisions, + "brevity_penalty": bp, + "length_ratio": ratio, + "translation_length": translation_length, + "reference_length": reference_length, + } diff --git a/diana/downstreamdeca/metric/rouge_local/rouge_metric.py b/diana/downstreamdeca/metric/rouge_local/rouge_metric.py new file mode 100644 index 0000000..66eb7f1 --- /dev/null +++ b/diana/downstreamdeca/metric/rouge_local/rouge_metric.py @@ -0,0 +1,133 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ROUGE metric from Google Research github repo. """ + +# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt +import absl # Here to have a nice missing dependency error message early on +import nltk # Here to have a nice missing dependency error message early on +import numpy # Here to have a nice missing dependency error message early on +import six # Here to have a nice missing dependency error message early on +from .rouge_score import RougeScorer +from .scoring import * +from .rouge_tokenizers import * +import datasets + + +_CITATION = """\ +@inproceedings{lin-2004-rouge, + title = "{ROUGE}: A Package for Automatic Evaluation of Summaries", + author = "Lin, Chin-Yew", + booktitle = "Text Summarization Branches Out", + month = jul, + year = "2004", + address = "Barcelona, Spain", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/W04-1013", + pages = "74--81", +} +""" + +_DESCRIPTION = """\ +ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for +evaluating automatic summarization and machine translation software in natural language processing. 
+The metrics compare an automatically produced summary or translation against a reference or a set of references (human-produced) summary or translation. + +Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters. + +This metrics is a wrapper around Google Research reimplementation of ROUGE: +https://github.com/google-research/google-research/tree/master/rouge +""" + +_KWARGS_DESCRIPTION = """ +Calculates average rouge scores for a list of hypotheses and references +Args: + predictions: list of predictions to score. Each prediction + should be a string with tokens separated by spaces. + references: list of reference for each prediction. Each + reference should be a string with tokens separated by spaces. + rouge_types: A list of rouge types to calculate. + Valid names: + `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring, + `"rougeL"`: Longest common subsequence based scoring. + `"rougeLSum"`: rougeLsum splits text using `"\n"`. + See details in https://github.com/huggingface/datasets/issues/617 + use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes. + use_aggregator: Return aggregates if this is set to True +Returns: + rouge1: rouge_1 (precision, recall, f1), + rouge2: rouge_2 (precision, recall, f1), + rougeL: rouge_l (precision, recall, f1), + rougeLsum: rouge_lsum (precision, recall, f1) +Examples: + + >>> rouge = datasets.load_metric('rouge') + >>> predictions = ["hello there", "general kenobi"] + >>> references = ["hello there", "general kenobi"] + >>> results = rouge.compute(predictions=predictions, references=references) + >>> print(list(results.keys())) + ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'] + >>> print(results["rouge1"]) + AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)) + >>> print(results["rouge1"].mid.fmeasure) + 1.0 +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Rouge(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Value("string", id="sequence"), + } + ), + codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"], + reference_urls=[ + "https://en.wikipedia.org/wiki/ROUGE_(metric)", + "https://github.com/google-research/google-research/tree/master/rouge", + ], + ) + + def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False): + if rouge_types is None: + rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] + + scorer = RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer) + if use_aggregator: + aggregator = BootstrapAggregator() + else: + scores = [] + + for ref, pred in zip(references, predictions): + score = scorer.score(ref, pred) + if use_aggregator: + aggregator.add_scores(score) + else: + scores.append(score) + + if use_aggregator: + result = aggregator.aggregate() + for key in rouge_types: + result[key] = result[key].mid.fmeasure + else: + result = {} + for key in scores[0]: + result[key] = list(score[key] for score in scores) + + return result diff --git a/diana/downstreamdeca/metric/rouge_local/rouge_score.py 
b/diana/downstreamdeca/metric/rouge_local/rouge_score.py new file mode 100644 index 0000000..eaed59b --- /dev/null +++ b/diana/downstreamdeca/metric/rouge_local/rouge_score.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Computes rouge scores between two text blobs. +Implementation replicates the functionality in the original ROUGE package. See: +Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In +Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), +Barcelona, Spain, July 25 - 26, 2004. +Default options are equivalent to running: +ROUGE-1.5.5.pl -e data -n 2 -a settings.xml +Or with use_stemmer=True: +ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml +In these examples settings.xml lists input files and formats. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from absl import logging +import nltk +import numpy as np +import six +from six.moves import map +from six.moves import range +from .scoring import * +from .rouge_tokenizers import * + + + +class RougeScorer(BaseScorer): + """Calculate rouges scores between two blobs of text. + Sample usage: + scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True) + scores = scorer.score('The quick brown fox jumps over the lazy dog', + 'The quick brown dog jumps on the log.') + """ + + def __init__(self, rouge_types, use_stemmer=False, split_summaries=False, + tokenizer=None): + """Initializes a new RougeScorer. + Valid rouge types that can be computed are: + rougen (e.g. rouge1, rouge2): n-gram based scoring. + rougeL: Longest common subsequence based scoring. + Args: + rouge_types: A list of rouge types to calculate. + use_stemmer: Bool indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. This arg is used in the + DefaultTokenizer, but other tokenizers might or might not choose to + use this. + split_summaries: whether to add newlines between sentences for rougeLsum + tokenizer: Tokenizer object which has a tokenize() method. + Returns: + A dict mapping rouge types to Score tuples. + """ + + self.rouge_types = rouge_types + if tokenizer: + self._tokenizer = tokenizer + else: + self._tokenizer = DefaultTokenizer(use_stemmer) + logging.info("Using default tokenizer.") + + self._split_summaries = split_summaries + + def score_multi(self, targets, prediction): + """Calculates rouge scores between targets and prediction. + The target with the maximum f-measure is used for the final score for + each score type.. + Args: + targets: list of texts containing the targets + prediction: Text containing the predicted text. + Returns: + A dict mapping each rouge type to a Score object. + Raises: + ValueError: If an invalid rouge type is encountered. 
+ """ + score_dicts = [self.score(t, prediction) for t in targets] + max_score = {} + for k in self.rouge_types: + index = np.argmax([s[k].fmeasure for s in score_dicts]) + max_score[k] = score_dicts[index][k] + + return max_score + + def score(self, target, prediction): + """Calculates rouge scores between the target and prediction. + Args: + target: Text containing the target (ground truth) text, + or if a list + prediction: Text containing the predicted text. + Returns: + A dict mapping each rouge type to a Score object. + Raises: + ValueError: If an invalid rouge type is encountered. + """ + + # Pre-compute target tokens and prediction tokens for use by different + # types, except if only "rougeLsum" is requested. + if len(self.rouge_types) == 1 and self.rouge_types[0] == "rougeLsum": + target_tokens = None + prediction_tokens = None + else: + target_tokens = self._tokenizer.tokenize(target) + prediction_tokens = self._tokenizer.tokenize(prediction) + result = {} + + for rouge_type in self.rouge_types: + if rouge_type == "rougeL": + # Rouge from longest common subsequences. + scores = _score_lcs(target_tokens, prediction_tokens) + elif rouge_type == "rougeLsum": + # Note: Does not support multi-line text. + def get_sents(text): + if self._split_summaries: + sents = nltk.sent_tokenize(text) + else: + # Assume sentences are separated by newline. + sents = six.ensure_str(text).split("\n") + sents = [x for x in sents if len(x)] + return sents + + target_tokens_list = [ + self._tokenizer.tokenize(s) for s in get_sents(target)] + prediction_tokens_list = [ + self._tokenizer.tokenize(s) for s in get_sents(prediction)] + + scores = _summary_level_lcs(target_tokens_list, + prediction_tokens_list) + elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)): + # Rouge from n-grams. + n = int(rouge_type[5:]) + if n <= 0: + raise ValueError("rougen requires positive n: %s" % rouge_type) + target_ngrams = _create_ngrams(target_tokens, n) + prediction_ngrams = _create_ngrams(prediction_tokens, n) + scores = _score_ngrams(target_ngrams, prediction_ngrams) + else: + raise ValueError("Invalid rouge type: %s" % rouge_type) + result[rouge_type] = scores + + return result + + +def _create_ngrams(tokens, n): + """Creates ngrams from the given list of tokens. + Args: + tokens: A list of tokens from which ngrams are created. + n: Number of tokens to use, e.g. 2 for bigrams. + Returns: + A dictionary mapping each bigram to the number of occurrences. + """ + + ngrams = collections.Counter() + for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)): + ngrams[ngram] += 1 + return ngrams + + +def _score_lcs(target_tokens, prediction_tokens): + """Computes LCS (Longest Common Subsequence) rouge scores. + Args: + target_tokens: Tokens from the target text. + prediction_tokens: Tokens from the predicted text. + Returns: + A Score object containing computed scores. + """ + + if not target_tokens or not prediction_tokens: + return Score(precision=0, recall=0, fmeasure=0) + + # Compute length of LCS from the bottom up in a table (DP appproach). 
+ lcs_table = _lcs_table(target_tokens, prediction_tokens) + lcs_length = lcs_table[-1][-1] + + precision = lcs_length / len(prediction_tokens) + recall = lcs_length / len(target_tokens) + fmeasure = func_fmeasure(precision, recall) + + return Score(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _lcs_table(ref, can): + """Create 2-d LCS score table.""" + rows = len(ref) + cols = len(can) + lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)] + for i in range(1, rows + 1): + for j in range(1, cols + 1): + if ref[i - 1] == can[j - 1]: + lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1 + else: + lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1]) + return lcs_table + + +def _backtrack_norec(t, ref, can): + """Read out LCS.""" + i = len(ref) + j = len(can) + lcs = [] + while i > 0 and j > 0: + if ref[i - 1] == can[j - 1]: + lcs.insert(0, i-1) + i -= 1 + j -= 1 + elif t[i][j - 1] > t[i - 1][j]: + j -= 1 + else: + i -= 1 + return lcs + + +def _summary_level_lcs(ref_sent, can_sent): + """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper. + Args: + ref_sent: list of tokenized reference sentences + can_sent: list of tokenized candidate sentences + Returns: + summary level ROUGE score + """ + if not ref_sent or not can_sent: + return Score(precision=0, recall=0, fmeasure=0) + + m = sum(map(len, ref_sent)) + n = sum(map(len, can_sent)) + if not n or not m: + return Score(precision=0, recall=0, fmeasure=0) + + # get token counts to prevent double counting + token_cnts_r = collections.Counter() + token_cnts_c = collections.Counter() + for s in ref_sent: + # s is a list of tokens + token_cnts_r.update(s) + for s in can_sent: + token_cnts_c.update(s) + + hits = 0 + for r in ref_sent: + lcs = _union_lcs(r, can_sent) + # Prevent double-counting: + # The paper describes just computing hits += len(_union_lcs()), + # but the implementation prevents double counting. We also + # implement this as in version 1.5.5. + for t in lcs: + if token_cnts_c[t] > 0 and token_cnts_r[t] > 0: + hits += 1 + token_cnts_c[t] -= 1 + token_cnts_r[t] -= 1 + + recall = hits / m + precision = hits / n + fmeasure = func_fmeasure(precision, recall) + return Score(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _union_lcs(ref, c_list): + """Find union LCS between a ref sentence and list of candidate sentences. + Args: + ref: list of tokens + c_list: list of list of indices for LCS into reference summary + Returns: + List of tokens in ref representing union LCS. + """ + lcs_list = [lcs_ind(ref, c) for c in c_list] + return [ref[i] for i in _find_union(lcs_list)] + + +def _find_union(lcs_list): + """Finds union LCS given a list of LCS.""" + return sorted(list(set().union(*lcs_list))) + + +def lcs_ind(ref, can): + """Returns one of the longest lcs.""" + t = _lcs_table(ref, can) + return _backtrack_norec(t, ref, can) + + +def _score_ngrams(target_ngrams, prediction_ngrams): + """Compute n-gram based rouge scores. + Args: + target_ngrams: A Counter object mapping each ngram to number of + occurrences for the target text. + prediction_ngrams: A Counter object mapping each ngram to number of + occurrences for the prediction text. + Returns: + A Score object containing computed scores. 
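ROUGE-L reduces to a longest-common-subsequence length: precision is LCS/|prediction tokens|, recall is LCS/|target tokens|, and the two are combined into an F-measure. A self-contained sketch of the computation performed by `_lcs_table` and `_score_lcs`:

    def lcs_length(ref, can):
        # Bottom-up DP table, as in _lcs_table above.
        table = [[0] * (len(can) + 1) for _ in range(len(ref) + 1)]
        for i in range(1, len(ref) + 1):
            for j in range(1, len(can) + 1):
                if ref[i - 1] == can[j - 1]:
                    table[i][j] = table[i - 1][j - 1] + 1
                else:
                    table[i][j] = max(table[i - 1][j], table[i][j - 1])
        return table[-1][-1]

    target = "the quick brown fox jumps".split()
    prediction = "the brown fox jumped".split()
    lcs = lcs_length(target, prediction)                      # "the brown fox" -> 3
    precision = lcs / len(prediction)                         # 3 / 4
    recall = lcs / len(target)                                # 3 / 5
    fmeasure = 2 * precision * recall / (precision + recall)  # ~0.667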
+ """ + + intersection_ngrams_count = 0 + for ngram in six.iterkeys(target_ngrams): + intersection_ngrams_count += min(target_ngrams[ngram], + prediction_ngrams[ngram]) + target_ngrams_count = sum(target_ngrams.values()) + prediction_ngrams_count = sum(prediction_ngrams.values()) + + precision = intersection_ngrams_count / max(prediction_ngrams_count, 1) + recall = intersection_ngrams_count / max(target_ngrams_count, 1) + fmeasure = func_fmeasure(precision, recall) + + return Score(precision=precision, recall=recall, fmeasure=fmeasure) \ No newline at end of file diff --git a/diana/downstreamdeca/metric/rouge_local/rouge_tokenizers.py b/diana/downstreamdeca/metric/rouge_local/rouge_tokenizers.py new file mode 100644 index 0000000..91e835b --- /dev/null +++ b/diana/downstreamdeca/metric/rouge_local/rouge_tokenizers.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +"""Library containing Tokenizer definitions. +The RougeScorer class can be instantiated with the tokenizers defined here. New +tokenizers can be defined by creating a subclass of the Tokenizer abstract class +and overriding the tokenize() method. +""" +import abc +from nltk.stem import porter + + +import re +import six + + +# Pre-compile regexes that are use often +NON_ALPHANUM_PATTERN = r"[^a-z0-9]+" +NON_ALPHANUM_RE = re.compile(NON_ALPHANUM_PATTERN) +SPACES_PATTERN = r"\s+" +SPACES_RE = re.compile(SPACES_PATTERN) +VALID_TOKEN_PATTERN = r"^[a-z0-9]+$" +VALID_TOKEN_RE = re.compile(VALID_TOKEN_PATTERN) + + +def _tokenize(text, stemmer): + """Tokenize input text into a list of tokens. + This approach aims to replicate the approach taken by Chin-Yew Lin in + the original ROUGE implementation. + Args: + text: A text blob to tokenize. + stemmer: An optional stemmer. + Returns: + A list of string tokens extracted from input text. + """ + + # Convert everything to lowercase. + text = text.lower() + # Replace any non-alpha-numeric characters with spaces. + text = NON_ALPHANUM_RE.sub(" ", six.ensure_str(text)) + + tokens = SPACES_RE.split(text) + if stemmer: + # Only stem words more than 3 characters long. + tokens = [six.ensure_str(stemmer.stem(x)) if len(x) > 3 else x + for x in tokens] + + # One final check to drop any empty or invalid tokens. + tokens = [x for x in tokens if VALID_TOKEN_RE.match(x)] + + return tokens + + +class Tokenizer(abc.ABC): + """Abstract base class for a tokenizer. + Subclasses of Tokenizer must implement the tokenize() method. + """ + + @abc.abstractmethod + def tokenize(self, text): + raise NotImplementedError("Tokenizer must override tokenize() method") + + +class DefaultTokenizer(Tokenizer): + """Default tokenizer which tokenizes on whitespace.""" + + def __init__(self, use_stemmer=False): + """Constructor for DefaultTokenizer. 
+ Args: + use_stemmer: boolean, indicating whether Porter stemmer should be used to + strip word suffixes to improve matching. + """ + self._stemmer = porter.PorterStemmer() if use_stemmer else None + + def tokenize(self, text): + return _tokenize(text, self._stemmer) \ No newline at end of file diff --git a/diana/downstreamdeca/metric/rouge_local/scoring.py b/diana/downstreamdeca/metric/rouge_local/scoring.py new file mode 100644 index 0000000..0a924ae --- /dev/null +++ b/diana/downstreamdeca/metric/rouge_local/scoring.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library for scoring and evaluation of text samples. +Aggregation functions use bootstrap resampling to compute confidence intervals +as per the original ROUGE perl implementation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +from typing import Dict + +import numpy as np +import six +from six.moves import range + + +class Score( + collections.namedtuple("Score", ["precision", "recall", "fmeasure"])): + """Tuple containing precision, recall, and f-measure values.""" + + +class BaseScorer(object, metaclass=abc.ABCMeta): + """Base class for Scorer objects.""" + + @abc.abstractmethod + def score(self, target, prediction): + """Calculates score between the target and prediction. + Args: + target: Text containing the target (ground truth) text. + prediction: Text containing the predicted text. + Returns: + A dict mapping each score_type (string) to Score object. + """ + + +class AggregateScore( + collections.namedtuple("AggregateScore", ["low", "mid", "high"])): + """Tuple containing confidence intervals for scores.""" + + +class BootstrapAggregator(object): + """Aggregates scores to provide confidence intervals. + Sample usage: + scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL']) + aggregator = Aggregator() + aggregator.add_scores(scorer.score("one two three", "one two")) + aggregator.add_scores(scorer.score("one two five six", "seven eight")) + result = aggregator.aggregate() + print result + {'rougeL': AggregateScore( + low=Score(precision=0.0, recall=0.0, fmeasure=0.0), + mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), + high=Score(precision=1.0, recall=0.66, fmeasure=0.80)), + 'rouge1': AggregateScore( + low=Score(precision=0.0, recall=0.0, fmeasure=0.0), + mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), + high=Score(precision=1.0, recall=0.66, fmeasure=0.80))} + """ + + def __init__(self, confidence_interval=0.95, n_samples=1000): + """Initializes a BootstrapAggregator object. + Args: + confidence_interval: Confidence interval to compute on the mean as a + decimal. + n_samples: Number of samples to use for bootstrap resampling. + Raises: + ValueError: If invalid argument is given. 
+ """ + + if confidence_interval < 0 or confidence_interval > 1: + raise ValueError("confidence_interval must be in range [0, 1]") + if n_samples <= 0: + raise ValueError("n_samples must be positive") + + self._n_samples = n_samples + self._confidence_interval = confidence_interval + self._scores = collections.defaultdict(list) + + def add_scores(self, scores): + """Adds a sample for future aggregation. + Args: + scores: Dict mapping score_type strings to a namedtuple object/class + representing a score. + """ + + for score_type, score in six.iteritems(scores): + self._scores[score_type].append(score) + + def aggregate(self): + """Aggregates scores previously added using add_scores. + Returns: + A dict mapping score_type to AggregateScore objects. + """ + + result = {} + for score_type, scores in six.iteritems(self._scores): + # Stack scores into a 2-d matrix of (sample, measure). + score_matrix = np.vstack(tuple(scores)) + # Percentiles are returned as (interval, measure). + percentiles = self._bootstrap_resample(score_matrix) + # Extract the three intervals (low, mid, high). + intervals = tuple( + (scores[0].__class__(*percentiles[j, :]) for j in range(3))) + result[score_type] = AggregateScore( + low=intervals[0], mid=intervals[1], high=intervals[2]) + return result + + def _bootstrap_resample(self, matrix): + """Performs bootstrap resampling on a matrix of scores. + Args: + matrix: A 2-d matrix of (sample, measure). + Returns: + A 2-d matrix of (bounds, measure). There are three bounds: low (row 0), + mid (row 1) and high (row 2). Mid is always the mean, while low and high + bounds are specified by self._confidence_interval (which defaults to 0.95 + meaning it will return the 2.5th and 97.5th percentiles for a 95% + confidence interval on the mean). + """ + + # Matrix of (bootstrap sample, measure). + sample_mean = np.zeros((self._n_samples, matrix.shape[1])) + for i in range(self._n_samples): + sample_idx = np.random.choice( + np.arange(matrix.shape[0]), size=matrix.shape[0]) + sample = matrix[sample_idx, :] + sample_mean[i, :] = np.mean(sample, axis=0) + + # Take percentiles on the estimate of the mean using bootstrap samples. + # Final result is a (bounds, measure) matrix. + percentile_delta = (1 - self._confidence_interval) / 2 + q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta]) + return np.percentile(sample_mean, q, axis=0) + + +def func_fmeasure(precision, recall): + """Computes f-measure given precision and recall values.""" + + if precision + recall > 0: + return 2 * precision * recall / (precision + recall) + else: + return 0.0 \ No newline at end of file diff --git a/diana/downstreamdeca/metric/squad_v1_local/evaluate.py b/diana/downstreamdeca/metric/squad_v1_local/evaluate.py new file mode 100644 index 0000000..41effa7 --- /dev/null +++ b/diana/downstreamdeca/metric/squad_v1_local/evaluate.py @@ -0,0 +1,92 @@ +""" Official evaluation script for v1.1 of the SQuAD dataset. 
""" + +import argparse +import json +import re +import string +import sys +from collections import Counter + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article["paragraphs"]: + for qa in paragraph["qas"]: + total += 1 + if qa["id"] not in predictions: + message = "Unanswered question " + qa["id"] + " will receive score 0." + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x["text"], qa["answers"])) + prediction = predictions[qa["id"]] + exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {"exact_match": exact_match, "f1": f1} + + +if __name__ == "__main__": + expected_version = "1.1" + parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version) + parser.add_argument("dataset_file", help="Dataset file") + parser.add_argument("prediction_file", help="Prediction File") + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if dataset_json["version"] != expected_version: + print( + "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"], + file=sys.stderr, + ) + dataset = dataset_json["data"] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/diana/downstreamdeca/metric/squad_v1_local/squad_v1_local.py b/diana/downstreamdeca/metric/squad_v1_local/squad_v1_local.py new file mode 100644 index 0000000..1456f63 --- /dev/null +++ b/diana/downstreamdeca/metric/squad_v1_local/squad_v1_local.py @@ -0,0 +1,109 @@ +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SQuAD metric. """ + +import datasets + +from .evaluate import evaluate + + +_CITATION = """\ +@inproceedings{Rajpurkar2016SQuAD10, + title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, + author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, + booktitle={EMNLP}, + year={2016} +} +""" + +_DESCRIPTION = """ +This metric wrap the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD). + +Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by +crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, +from the corresponding reading passage, or the question might be unanswerable. +""" + +_KWARGS_DESCRIPTION = """ +Computes SQuAD scores (F1 and EM). +Args: + predictions: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair as given in the references (see below) + - 'prediction_text': the text of the answer + references: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair (see above), + - 'answers': a Dict in the SQuAD dataset format + { + 'text': list of possible texts for the answer, as a list of strings + 'answer_start': list of start positions for the answer, as a list of ints + } + Note that answer_start values are not taken into account to compute the metric. 
+Returns: + 'exact_match': Exact match (the normalized answer exactly match the gold answer) + 'f1': The F-score of predicted tokens versus the gold answer +Examples: + + >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}] + >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] + >>> squad_metric = datasets.load_metric("squad") + >>> results = squad_metric.compute(predictions=predictions, references=references) + >>> print(results) + {'exact_match': 100.0, 'f1': 100.0} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Squad(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")}, + "references": { + "id": datasets.Value("string"), + "answers": datasets.features.Sequence( + { + "text": datasets.Value("string"), + "answer_start": datasets.Value("int32"), + } + ), + }, + } + ), + codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + ) + + def _compute(self, predictions, references): + pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions} + dataset = [ + { + "paragraphs": [ + { + "qas": [ + { + "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]], + "id": ref["id"], + } + for ref in references + ] + } + ] + } + ] + score = evaluate(dataset=dataset, predictions=pred_dict) + return score diff --git a/diana/downstreamdeca/metric/squad_v2_local/evaluate.py b/diana/downstreamdeca/metric/squad_v2_local/evaluate.py new file mode 100644 index 0000000..1a04e53 --- /dev/null +++ b/diana/downstreamdeca/metric/squad_v2_local/evaluate.py @@ -0,0 +1,330 @@ +"""Official evaluation script for SQuAD version 2.0. + +In addition to basic functionality, we also compute additional statistics and +plot precision-recall curves if an additional na_prob.json file is provided. +This file is expected to map question ID's to the model's predicted probability +that a question is unanswerable. +""" +import argparse +import collections +import json +import os +import re +import string +import sys + +import numpy as np + + +OPTS = None + + +def parse_args(): + parser = argparse.ArgumentParser("Official evaluation script for SQuAD version 2.0.") + parser.add_argument("data_file", metavar="data.json", help="Input data JSON file.") + parser.add_argument("pred_file", metavar="pred.json", help="Model predictions.") + parser.add_argument( + "--out-file", "-o", metavar="eval.json", help="Write accuracy metrics to file (default is stdout)." + ) + parser.add_argument( + "--na-prob-file", "-n", metavar="na_prob.json", help="Model estimates of probability of no answer." + ) + parser.add_argument( + "--na-prob-thresh", + "-t", + type=float, + default=1.0, + help='Predict "" if no-answer probability exceeds this (default = 1.0).', + ) + parser.add_argument( + "--out-image-dir", "-p", metavar="out_images", default=None, help="Save precision-recall curves to directory." 
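The wrapper's `_compute` only reshapes the Hugging Face-style inputs into what the official script expects: a flat `{id: prediction_text}` dict and the nested paragraphs/qas layout. The same reshaping, spelled out on a single example:

    predictions = [{"id": "q1", "prediction_text": "1976"}]
    references = [{"id": "q1", "answers": {"text": ["1976"], "answer_start": [97]}}]

    pred_dict = {p["id"]: p["prediction_text"] for p in predictions}
    dataset = [{
        "paragraphs": [{
            "qas": [{"id": r["id"],
                     "answers": [{"text": t} for t in r["answers"]["text"]]}
                    for r in references]
        }]
    }]
    # evaluate(dataset, pred_dict) -> {"exact_match": 100.0, "f1": 100.0}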
+ ) + parser.add_argument("--verbose", "-v", action="store_true") + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args() + + +def make_qid_to_has_ans(dataset): + qid_to_has_ans = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid_to_has_ans[qa["id"]] = bool(qa["answers"]["text"]) + return qid_to_has_ans + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) + return re.sub(regex, " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def get_tokens(s): + if not s: + return [] + return normalize_answer(s).split() + + +def compute_exact(a_gold, a_pred): + return int(normalize_answer(a_gold) == normalize_answer(a_pred)) + + +def compute_f1(a_gold, a_pred): + gold_toks = get_tokens(a_gold) + pred_toks = get_tokens(a_pred) + common = collections.Counter(gold_toks) & collections.Counter(pred_toks) + num_same = sum(common.values()) + if len(gold_toks) == 0 or len(pred_toks) == 0: + # If either is no-answer, then F1 is 1 if they agree, 0 otherwise + return int(gold_toks == pred_toks) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(pred_toks) + recall = 1.0 * num_same / len(gold_toks) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def get_raw_scores(dataset, preds): + exact_scores = {} + f1_scores = {} + for article in dataset: + for p in article["paragraphs"]: + for qa in p["qas"]: + qid = qa["id"] + if qid not in preds: + print(f"Missing prediction for {qid}") + continue + a_pred = preds[qid] + + gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)] + if not gold_answers: + # For unanswerable questions, only correct answer is empty string + gold_answers = [""] + if a_pred != "": + exact_scores[qid] = 0 + f1_scores[qid] = 0 + else: + exact_scores[qid] = 1 + f1_scores[qid] = 1 + else: + # Take max over all gold answers + exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) + f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) + return exact_scores, f1_scores + + +def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): + new_scores = {} + for qid, s in scores.items(): + pred_na = na_probs[qid] > na_prob_thresh + if pred_na: + new_scores[qid] = float(not qid_to_has_ans[qid]) + else: + new_scores[qid] = s + return new_scores + + +def make_eval_dict(exact_scores, f1_scores, qid_list=None): + if not qid_list: + total = len(exact_scores) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores.values()) / total), + ("f1", 100.0 * sum(f1_scores.values()) / total), + ("total", total), + ] + ) + else: + total = len(qid_list) + return collections.OrderedDict( + [ + ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), + ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), + ("total", total), + ] + ) + + +def merge_eval(main_eval, new_eval, prefix): + for k in new_eval: + main_eval[f"{prefix}_{k}"] = new_eval[k] + + +def plot_pr_curve(precisions, recalls, out_image, title): + plt.step(recalls, precisions, color="b", alpha=0.2, where="post") + plt.fill_between(recalls, precisions, step="post", alpha=0.2, 
color="b") + plt.xlabel("Recall") + plt.ylabel("Precision") + plt.xlim([0.0, 1.05]) + plt.ylim([0.0, 1.05]) + plt.title(title) + plt.savefig(out_image) + plt.clf() + + +def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None): + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + true_pos = 0.0 + cur_p = 1.0 + cur_r = 0.0 + precisions = [1.0] + recalls = [0.0] + avg_prec = 0.0 + for i, qid in enumerate(qid_list): + if qid_to_has_ans[qid]: + true_pos += scores[qid] + cur_p = true_pos / float(i + 1) + cur_r = true_pos / float(num_true_pos) + if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]: + # i.e., if we can put a threshold after this point + avg_prec += cur_p * (cur_r - recalls[-1]) + precisions.append(cur_p) + recalls.append(cur_r) + if out_image: + plot_pr_curve(precisions, recalls, out_image, title) + return {"ap": 100.0 * avg_prec} + + +def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir): + if out_image_dir and not os.path.exists(out_image_dir): + os.makedirs(out_image_dir) + num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) + if num_true_pos == 0: + return + pr_exact = make_precision_recall_eval( + exact_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_exact.png"), + title="Precision-Recall curve for Exact Match score", + ) + pr_f1 = make_precision_recall_eval( + f1_raw, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_f1.png"), + title="Precision-Recall curve for F1 score", + ) + oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} + pr_oracle = make_precision_recall_eval( + oracle_scores, + na_probs, + num_true_pos, + qid_to_has_ans, + out_image=os.path.join(out_image_dir, "pr_oracle.png"), + title="Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)", + ) + merge_eval(main_eval, pr_exact, "pr_exact") + merge_eval(main_eval, pr_f1, "pr_f1") + merge_eval(main_eval, pr_oracle, "pr_oracle") + + +def histogram_na_prob(na_probs, qid_list, image_dir, name): + if not qid_list: + return + x = [na_probs[k] for k in qid_list] + weights = np.ones_like(x) / float(len(x)) + plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) + plt.xlabel("Model probability of no-answer") + plt.ylabel("Proportion of dataset") + plt.title(f"Histogram of no-answer probability: {name}") + plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png")) + plt.clf() + + +def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + return 100.0 * best_score / len(scores), best_thresh + + +def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval["best_exact"] = best_exact + main_eval["best_exact_thresh"] = exact_thresh + main_eval["best_f1"] = best_f1 + main_eval["best_f1_thresh"] = f1_thresh + + +def main(): + with open(OPTS.data_file) as f: + dataset_json = json.load(f) + dataset = dataset_json["data"] + with open(OPTS.pred_file) as f: + preds = json.load(f) + if OPTS.na_prob_file: + with open(OPTS.na_prob_file) as f: + na_probs = json.load(f) + else: + na_probs = {k: 0.0 for k in preds} + qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + exact_raw, f1_raw = get_raw_scores(dataset, preds) + exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh) + out_eval = make_eval_dict(exact_thresh, f1_thresh) + if has_ans_qids: + has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + merge_eval(out_eval, has_ans_eval, "HasAns") + if no_ans_qids: + no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + merge_eval(out_eval, no_ans_eval, "NoAns") + if OPTS.na_prob_file: + find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) + if OPTS.na_prob_file and OPTS.out_image_dir: + run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir) + histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns") + histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns") + if OPTS.out_file: + with open(OPTS.out_file, "w") as f: + json.dump(out_eval, f) + else: + print(json.dumps(out_eval, indent=2)) + + +if __name__ == "__main__": + OPTS = parse_args() + if OPTS.out_image_dir: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + main() \ No newline at end of file diff --git a/diana/downstreamdeca/metric/squad_v2_local/squad_v2_local.py 
b/diana/downstreamdeca/metric/squad_v2_local/squad_v2_local.py new file mode 100644 index 0000000..d88979a --- /dev/null +++ b/diana/downstreamdeca/metric/squad_v2_local/squad_v2_local.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SQuAD v2 metric. """ + +import datasets + +from .evaluate import ( + apply_no_ans_threshold, + find_all_best_thresh, + get_raw_scores, + make_eval_dict, + make_qid_to_has_ans, + merge_eval, +) + + +_CITATION = """\ +@inproceedings{Rajpurkar2016SQuAD10, + title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text}, + author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang}, + booktitle={EMNLP}, + year={2016} +} +""" + +_DESCRIPTION = """ +This metric wrap the official scoring script for version 2 of the Stanford Question +Answering Dataset (SQuAD). + +Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by +crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, +from the corresponding reading passage, or the question might be unanswerable. + +SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable questions +written adversarially by crowdworkers to look similar to answerable ones. +To do well on SQuAD2.0, systems must not only answer questions when possible, but also +determine when no answer is supported by the paragraph and abstain from answering. +""" + +_KWARGS_DESCRIPTION = """ +Computes SQuAD v2 scores (F1 and EM). +Args: + predictions: List of triple for question-answers to score with the following elements: + - the question-answer 'id' field as given in the references (see below) + - the text of the answer + - the probability that the question has no answer + references: List of question-answers dictionaries with the following key-values: + - 'id': id of the question-answer pair (see above), + - 'answers': a list of Dict {'text': text of the answer as a string} + no_answer_threshold: float + Probability threshold to decide that a question has no answer. 
+Returns: + 'exact': Exact match (the normalized answer exactly match the gold answer) + 'f1': The F-score of predicted tokens versus the gold answer + 'total': Number of score considered + 'HasAns_exact': Exact match (the normalized answer exactly match the gold answer) + 'HasAns_f1': The F-score of predicted tokens versus the gold answer + 'HasAns_total': Number of score considered + 'NoAns_exact': Exact match (the normalized answer exactly match the gold answer) + 'NoAns_f1': The F-score of predicted tokens versus the gold answer + 'NoAns_total': Number of score considered + 'best_exact': Best exact match (with varying threshold) + 'best_exact_thresh': No-answer probability threshold associated to the best exact match + 'best_f1': Best F1 (with varying threshold) + 'best_f1_thresh': No-answer probability threshold associated to the best F1 +Examples: + + >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}] + >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}] + >>> squad_v2_metric = datasets.load_metric("squad_v2") + >>> results = squad_v2_metric.compute(predictions=predictions, references=references) + >>> print(results) + {'exact': 100.0, 'f1': 100.0, 'total': 1, 'HasAns_exact': 100.0, 'HasAns_f1': 100.0, 'HasAns_total': 1, 'best_exact': 100.0, 'best_exact_thresh': 0.0, 'best_f1': 100.0, 'best_f1_thresh': 0.0} +""" + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class SquadV2Local(datasets.Metric): + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "predictions": { + "id": datasets.Value("string"), + "prediction_text": datasets.Value("string"), + "no_answer_probability": datasets.Value("float32"), + }, + "references": { + "id": datasets.Value("string"), + "answers": datasets.features.Sequence( + {"text": datasets.Value("string"), "answer_start": datasets.Value("int32")} + ), + }, + } + ), + codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"], + ) + + def _compute(self, predictions, references, no_answer_threshold=1.0): + # no_answer_probabilities = dict((p["id"], p["no_answer_probability"]) for p in predictions) + dataset = [{"paragraphs": [{"qas": references}]}] + predictions = dict((p["id"], p["prediction_text"]) for p in predictions) + + # qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False + # has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + # no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + + exact_raw, f1_raw = get_raw_scores(dataset, predictions) + out_eval = make_eval_dict(exact_raw, f1_raw) + # + # + # exact_thresh = apply_no_ans_threshold(exact_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + # f1_thresh = apply_no_ans_threshold(f1_raw, no_answer_probabilities, qid_to_has_ans, no_answer_threshold) + # out_eval = make_eval_dict(exact_thresh, f1_thresh) + # + # if has_ans_qids: + # has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) + # merge_eval(out_eval, has_ans_eval, "HasAns") + # if no_ans_qids: + # no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) + # merge_eval(out_eval, no_ans_eval, "NoAns") + # find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, no_answer_probabilities, qid_to_has_ans) + 
return dict(out_eval) \ No newline at end of file diff --git a/diana/downstreamdeca/simple_processors.py b/diana/downstreamdeca/simple_processors.py new file mode 100755 index 0000000..6505679 --- /dev/null +++ b/diana/downstreamdeca/simple_processors.py @@ -0,0 +1,192 @@ +import sys +sys.path.append('../data_process') +from typing import List, Optional, Tuple +from QAInput import SimpleQAInput as QAInput +#from QAInput import QAInputNoPrompt as QAInput + +def preprocess_sqaud_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_sqaud_abstractive_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_boolq_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + question_column, context_column, answer_column = 'question', 'passage', 'answer' + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [str(ans) for ans in answers] #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_boolq_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_boolqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0].capitalize() if len(answer["text"]) > 0 else "" for answer in answers] + # print(inputs,targets) + return inputs, targets + +def preprocess_narrativeqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [exp['summary']['text'] for exp in examples['document']] + questions = [exp['text'] for exp in examples['question']] + answers = [ans[0]['text'] for ans in examples['answers']] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = answers #[answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_narrativeqa_batch_pretrain( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, 
contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_drop_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['passage'] + questions = examples['question'] + answers = examples['answers'] + inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_race_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}; D. {options[3]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3 } + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_newsqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + inputs = [QAInput.qg_input_extractive_qa(context, question) for question, context in zip(questions, contexts)] + # inputs = [QAInput.qg_input_abstrativeqa(context, question) for question, context in zip(questions, contexts)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + + +def preprocess_ropes_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples[question_column] + backgrounds = examples["background"] + situations = examples["situation"] + answers = examples[answer_column] + + inputs = [QAInput.qg_input_extractive_qa(" ".join([background, situation]), question) for question, background, situation in zip(questions, backgrounds, situations)] + targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers] + return inputs, targets + +def preprocess_openbookqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + questions = examples['question_stem'] + all_options = examples['choices'] + answers = examples['answerKey'] + options_texts = [f"options: A. {options['text'][0]}; B. {options['text'][1]}; C. {options['text'][2]}; D. {options['text'][3]}" for options in all_options] + inputs = [QAInput.qg_input_multirc("", question, ops) for question, ops in zip(questions, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3} + targets = [options['text'][ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_social_iqa_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = examples['article'] + questions = examples['question'] + all_options = examples['options'] + answers = examples['answer'] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. 
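Each `preprocess_*_batch` above flattens a dataset-specific example into a plain (input text, target text) pair; for multiple-choice sets the options are serialized into an `options: A. ...` string and the target is the chosen option's text rather than its letter. A standalone sketch of the RACE-style conversion; `format_example` here is an illustrative stand-in, not the repo's `QAInput.qg_input_multirc`, so the exact template is an assumption:

    def format_example(context, question, options):
        # Hypothetical formatter; the real code routes through QAInput.qg_input_multirc.
        letters = ["A", "B", "C", "D"]
        opts = "; ".join(f"{l}. {o}" for l, o in zip(letters, options))
        return f"question: {question} options: {opts} context: {context}"

    example = {
        "article": "Tom went to the market to buy apples.",
        "question": "Why did Tom go to the market?",
        "options": ["To buy apples", "To sell apples", "To meet a friend", "To work"],
        "answer": "A",
    }
    ans_map = {"A": 0, "B": 1, "C": 2, "D": 3}
    model_input = format_example(example["article"], example["question"], example["options"])
    target = example["options"][ans_map[example["answer"]]]   # "To buy apples"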
{options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + ans_map = {'A': 0, 'B': 1, 'C': 2,} + targets = [options[ans_map[answer]] for options, answer in zip(all_options, answers)] + return inputs, targets + +def preprocess_dream_batch( + examples, + question_column: str, + context_column: str, + answer_column: str, +) -> Tuple[List[str], List[str]]: + contexts = [" ".join(dialogue) for dialogue in examples['dialogue']] + questions = examples['question'] + all_options = examples['choice'] + answers = examples['answer'] + answer_idxs = [options.index(answer) for answer, options in zip(answers, all_options)] + options_texts = [f'options: A. {options[0]}; B. {options[1]}; C. {options[2]}' for options in all_options] + inputs = [QAInput.qg_input_multirc(context, question, ops) for question, context, ops in zip(questions, contexts, options_texts)] + targets = answers + return inputs, targets \ No newline at end of file diff --git a/diana/downstreamdeca/toknize.py b/diana/downstreamdeca/toknize.py new file mode 100644 index 0000000..b30a7c7 --- /dev/null +++ b/diana/downstreamdeca/toknize.py @@ -0,0 +1,248 @@ +import itertools +import json +import os +import csv +import errno +import random +from random import shuffle +from typing import List +#import spacy +from tqdm import tqdm +import codecs +#import nltk +import glob +import xml.etree.ElementTree as ET +from datasets import load_dataset +#import statistic +from QAInput import StructuralQAInput, SimpleQAInput +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + M2M100Tokenizer, + MBart50Tokenizer, + EvalPrediction, + MBart50TokenizerFast, + MBartTokenizer, + MBartTokenizerFast, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + default_data_collator, + set_seed, +) + +tokenizer = AutoTokenizer.from_pretrained("./t5-base") +#nltk.download("stopwords") +#from nltk.corpus import stopwords +task2id = {"wikisql":0,"squad":1,"srl":2,"sst":3,"woz.en":4,"iwslt.en.de":5,"cnn_dailymail":6,"multinli.in.out":7,"zre":8} +task2format={"wikisql":1,"squad":0,"srl":0,"sst":2,"woz.en":1,"iwslt.en.de":0,"cnn_dailymail":1,"multinli.in.out":2,"zre":0} +#STOPWORDS = stopwords.words("english") +#STOPWORDS = [stopword + " " for stopword in STOPWORDS] +#nlp = spacy.load("en_core_web_sm") + +def chuli_example(examples,fname): + cur_dataset = fname + + answers_start = [] + if fname in ["none"]: + addition_answers = open(fname+"_answers.json",'r') + faddition = json.load(addition_answers) + for item in faddition: + answers_start.append({"text":item,"answer_start":[0]*len(item)}) + else: + for item in examples["answer"]: + answers_start.append({"text":item,"answer_start":[0]*len(item)}) + print(answers_start[:10]) + examples = examples.add_column("answers",answers_start) + return examples + + +def preprocess_plain( + examples, + question_column: str, + context_column: str, + answer_column:str): + + questions = examples[question_column] + contexts = examples[context_column] + answers = examples[answer_column] + + inputs = [SimpleQAInput.qg_input(question, context) for question,context in zip(questions,contexts)] + answers = [_[0] for _ in answers] + return inputs,answers + + +def preprocess_proqa( + examples, + question_column: str, + context_column: str, + answer_column:str): + + questions = examples[question_column] + contexts = examples[context_column] + answers = 
examples[answer_column] + + inputs = [StructuralQAInput.qg_input(question, context) for question,context in zip(questions,contexts)] + answers = [_[0] for _ in answers] + return inputs,answers + + + + +def preprocess_function(examples): + preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "question","context","answer") + model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer(targets, max_length=128, padding=False, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)] + format_id = task2format[dataset_name] + format_prompt_id_start = 300 + format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10, + format_prompt_id_start + (format_id + 1) * 10)] + task_id = task2id[dataset_name] + task_prompt_id_start = 0 + task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20, + task_prompt_id_start + (task_id + 1) * 20)] + + + domain_prompt_id_start = 20*30 + domain_prompt_number = 20 + domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start, + domain_prompt_id_start + 20)]*5 + input_ids = copy.deepcopy( + [gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]) + model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']] + model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in + model_inputs['attention_mask']] + return model_inputs + +def preprocess_function_valid(examples): + preprocess_fn = preprocess_proqa#dataset_name_to_func(data_args.dataset_name) + inputs, targets = preprocess_fn(examples, "question","context","answer") + model_inputs = tokenizer(inputs, max_length=1024, padding=False, truncation=True) + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + + labels = tokenizer(targets, max_length=128, padding=False, truncation=True) + model_inputs["example_id"] = [] + + for i in range(len(model_inputs["input_ids"])): + # One example can give several spans, this is the index of the example containing this span of text. 
+ sample_index = i #sample_mapping[i] + model_inputs["example_id"].append(examples["id"][sample_index]) + + + + gen_prompt_ids = [-(i+1) for i in range(1520,1520+10)] + format_id = task2format[dataset_name] + format_prompt_id_start = 300 + format_prompt_ids = [-(i+1) for i in range(format_prompt_id_start + format_id * 10, + format_prompt_id_start + (format_id + 1) * 10)] + task_id = task2id[dataset_name] + task_prompt_id_start = 0 + task_prompt_ids = [- (i + 1) for i in range(task_prompt_id_start + task_id * 20, + task_prompt_id_start + (task_id + 1) * 20)] + + + domain_prompt_id_start = 20*30 + domain_prompt_number = 20 + domain_prompt_ids = [- (i + 1) for i in range(domain_prompt_id_start, + domain_prompt_id_start + 20)]*5 + input_ids = copy.deepcopy( + [gen_prompt_ids+format_prompt_ids+task_prompt_ids + domain_prompt_ids+input_ids for input_ids in model_inputs['input_ids']]) + model_inputs['input_ids'] = input_ids # [format_prompt_ids+input_ids for input_ids in model_inputs['input_ids']] + model_inputs['attention_mask'] = [[1] * 140 + attention_mask for attention_mask in + model_inputs['attention_mask']] + model_inputs["labels"] = labels["input_ids"] + + + return model_inputs + +def add_id(example,index): + example.update({'id':index}) + return example + + + +def prep(raw_ds,fname): + if True: + column_names = raw_ds.column_names + global dataset_name + dataset_name = fname + train_dataset = raw_ds.map( + preprocess_function, + batched=True, + num_proc=4, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on train dataset", + ) + train_dataset = train_dataset.add_column("id",range(len(train_dataset))) + train_dataset.save_to_disk("./ours/{}-train.hf".format(fname)) + + + + + + + +import copy +def prep_valid(raw_ds,fname): + global dataset_name + dataset_name = fname + eval_examples = copy.deepcopy(raw_ds) + eval_examples = chuli_example(eval_examples,fname) + column_names = raw_ds.column_names + if 'id' not in eval_examples.features.keys(): + eval_examples = eval_examples.map(add_id,with_indices=True) + + + if True: + eval_dataset = raw_ds.map(add_id,with_indices=True) + eval_dataset = eval_dataset.map( + preprocess_function_valid, + batched=True, + num_proc=4, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on validation dataset", + ) + eval_dataset.save_to_disk("./ours/{}-eval.hf".format(fname)) + eval_examples.save_to_disk("./ours/{}-evalex.hf".format(fname)) + + +from collections import Counter +import json + + +def main(): + for item in ["wikisql","squad","srl","sst","woz.en","iwslt.en.de","cnn_dailymail","multinli.in.out","zre"]: + print("current_dataset:",item) + print("Loading......") + data_files = {} + + data_files["validation"] = item+"-eval.jsonl" + test_dataset = load_dataset("json", data_files=data_files)["validation"] + if not item in ["cnn_dailymail","multinli.in.out","zre"]: + data_files = {} + data_files["train"] = item+"-train.jsonl" + train_dataset = load_dataset("json", data_files=data_files)["train"] + + + + print("Loaded.") + if not item in ["cnn_dailymail","multinli.in.out","zre"]: + prep(train_dataset,item) + print("preped") + prep_valid(test_dataset,item) + print("valid prepred") + + + +main() \ No newline at end of file diff --git a/diana/metrics.py b/diana/metrics.py new file mode 100644 index 0000000..746136c --- /dev/null +++ b/diana/metrics.py @@ -0,0 +1,477 @@ +import collections +import string +import re +import numpy as np +import json +AGG_OPS = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 
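`preprocess_function` and `preprocess_function_valid` prepend the same 140-position prompt prefix to every tokenized input: 10 general ids, 10 format ids chosen via `task2format`, 20 task ids chosen via `task2id`, and a 20-id domain block tiled 5 times, all encoded as negative ids so they cannot collide with vocabulary ids; the attention mask is extended by 140 ones to match. A condensed sketch of that layout (`build_prompt_prefix` is an illustrative helper, not a repo function):

    def build_prompt_prefix(task_id, format_id):
        gen_ids = [-(i + 1) for i in range(1520, 1530)]                                        # 10 general
        fmt_ids = [-(i + 1) for i in range(300 + format_id * 10, 300 + (format_id + 1) * 10)]  # 10 format
        task_ids = [-(i + 1) for i in range(task_id * 20, (task_id + 1) * 20)]                 # 20 task
        dom_ids = [-(i + 1) for i in range(600, 620)] * 5                                      # 20 domain ids x 5
        return gen_ids + fmt_ids + task_ids + dom_ids                                          # 140 ids in total

    prefix = build_prompt_prefix(task_id=1, format_id=0)   # e.g. the "squad" task
    assert len(prefix) == 140
    # input_ids      -> prefix + tokenized input ids
    # attention_mask -> [1] * 140 + original attention mask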
'AVG'] +COND_OPS = ['=', '>', '<', 'OP'] +COND_OPS = ['=', '>', '<'] +#from pyrouge import Rouge155 +#from multiprocessing import Pool, cpu_count +#from contextlib import closing +from datasets import load_metric + +def compute_rouge_scores(summs, refs, splitchar='.', options=None, parallel=True): + assert len(summs) == len(refs) + options = [ + '-a', # evaluate all systems + '-c', 95, # confidence interval + '-m', # use Porter stemmer + '-n', 2, # max-ngram + '-w', 1.3, # weight (weighting factor for WLCS) + ] + rr = Rouge(options=options) + rouge_args = [] + for summ, ref in zip(summs, refs): + letter = "A" + ref_dict = {} + for r in ref: + ref_dict[letter] = [x for x in split_sentences(r, splitchar) if len(x) > 0] + letter = chr(ord(letter) + 1) + s = [x for x in split_sentences(summ, splitchar) if len(x) > 0] + rouge_args.append((s, ref_dict)) + if parallel: + with closing(Pool(cpu_count()//2)) as pool: + rouge_scores = pool.starmap(rr.score_summary, rouge_args) + else: + rouge_scores = [] + for s, a in rouge_args: + rouge_scores.append(rr.score_summary(s, ref_dict)) + return rouge_scores + +def computeROUGE(greedy, answer): + rouges = compute_rouge_scores(greedy, answer) + if len(rouges) > 0: + avg_rouges = {} + for key in rouges[0].keys(): + avg_rouges[key] = sum( + [r.get(key, 0.0) for r in rouges]) / len(rouges) * 100 + else: + avg_rouges = None + return avg_rouges + +def lcsstr(string1, string2): + answer = 0 + len1, len2 = len(string1), len(string2) + for i in range(len1): + match = 0 + for j in range(len2): + if (i + j < len1 and string1[i + j] == string2[j]): + match += 1 + if match > answer: + answer = match + else: + match = 0 + return answer + + +def lcs(X, Y): + m = len(X) + n = len(Y) + L = [[None]*(n + 1) for i in range(m + 1)] + + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0 : + L[i][j] = 0 + elif X[i-1] == Y[j-1]: + L[i][j] = L[i-1][j-1]+1 + else: + L[i][j] = max(L[i-1][j], L[i][j-1]) + return L[m][n] / max(m, n) + lcsstr(X, Y) / 1e4 - min(m, n) / 1e8 + + +def normalize_text(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + def white_space_fix(text): + # print(text) + # print("after:",' '.join(text.split( ))) + return ' '.join(text.split( )) + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + def lower(text): + return text.lower() + # print(white_space_fix(remove_articles(remove_punc(lower(s))))) + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def to_delta_state(line): + delta_state = {'inform': {}, 'request': {}} + try: + if line.lower() == 'none' or line.strip() == '' or line.strip() == ';': + return delta_state + inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')] + inform_pairs = {} + for i in inform: + try: + k, v = i.split(':') + inform_pairs[k.strip()] = v.strip() + except: + pass + delta_state = {'inform': inform_pairs, 'request': request} + except: + pass + finally: + return delta_state + + +def update_state(state, delta): + for act, slot in delta.items(): + state[act] = slot + return state + + +def dict_cmp(d1, d2): + def cmp(a, b): + for k1, v1 in a.items(): + if k1 not in b: + return False + else: + if v1 != b[k1]: + return False + return True + return cmp(d1, d2) and cmp(d2, d1) + + +def to_lf(s, table): + # print("input:",s) + aggs = [y.lower() for y in AGG_OPS] + agg_to_idx = {x: i for i, x in 
enumerate(aggs)} + conditionals = [y.lower() for y in COND_OPS] + headers_unsorted = [(y.lower(), i) for i, y in enumerate(table['header'])] + headers = [(y.lower(), i) for i, y in enumerate(table['header'])] + headers.sort(reverse=True, key=lambda x: len(x[0])) + condition_s, conds = None, [] + if 'where' in s: + s, condition_s = s.split('where', 1) + # print("s,cond",s,condition_s) + s = ' '.join(s.split()[1:-2]) + s_no_agg = ' '.join(s.split()[1:]) + # print(s,s_no_agg,"agg") + sel, agg = None, 0 + lcss, idxs = [], [] + for col, idx in headers: + lcss.append(lcs(col, s)) + lcss.append(lcs(col, s_no_agg)) + idxs.append(idx) + lcss = np.array(lcss) + max_id = np.argmax(lcss) + sel = idxs[max_id // 2] + if max_id % 2 == 1: # with agg + agg = agg_to_idx[s.split()[0]] + + full_conditions = [] + if not condition_s is None: + + pattern = '|'.join(COND_OPS) + split_conds_raw = re.split(pattern, condition_s) + split_conds_raw = [conds.strip() for conds in split_conds_raw] + split_conds = [split_conds_raw[0]] + for i in range(1, len(split_conds_raw)-1): + split_conds.extend(re.split('and', split_conds_raw[i])) + split_conds += [split_conds_raw[-1]] + for i in range(0, len(split_conds), 2): + cur_s = split_conds[i] + lcss = [] + for col in headers: + lcss.append(lcs(col[0], cur_s)) + max_id = np.argmax(np.array(lcss)) + split_conds[i] = headers[max_id][0] + for i, m in enumerate(re.finditer(pattern, condition_s)): + split_conds[2*i] = split_conds[2*i] + ' ' + m.group() + split_conds = [' '.join(split_conds[2*i:2*i+2]) for i in range(len(split_conds)//2)] + + condition_s = ' and '.join(split_conds) + condition_s = ' ' + condition_s + ' ' + for idx, col in enumerate(headers): + condition_s = condition_s.replace(' ' + col[0] + ' ', ' Col{} '.format(col[1])) + condition_s = condition_s.strip() + + for idx, col in enumerate(conditionals): + new_s = [] + for t in condition_s.split(): + if t == col: + new_s.append('Cond{}'.format(idx)) + else: + new_s.append(t) + condition_s = ' '.join(new_s) + s = condition_s + + conds = re.split('(Col\d+ Cond\d+)', s) + if len(conds) == 0: + conds = [s] + conds = [x for x in conds if len(x.strip()) > 0] + full_conditions = [] + for i, x in enumerate(conds): + if i % 2 == 0: + x = x.split() + col_num = int(x[0].replace('Col', '')) + opp_num = int(x[1].replace('Cond', '')) + full_conditions.append([col_num, opp_num]) + else: + x = x.split() + if x[-1] == 'and': + x = x[:-1] + x = ' '.join(x) + if 'Col' in x: + new_x = [] + for t in x.split(): + if 'Col' in t: + idx = int(t.replace('Col', '')) + t = headers_unsorted[idx][0] + new_x.append(t) + x = new_x + x = ' '.join(x) + if 'Cond' in x: + new_x = [] + for t in x.split(): + if 'Cond' in t: + idx = int(t.replace('Cond', '')) + t = conditionals[idx] + new_x.append(t) + x = new_x + x = ' '.join(x) + full_conditions[-1].append(x.replace(' ', '')) + logical_form = {'sel': sel, 'conds': full_conditions, 'agg': agg} + return logical_form + + +def computeLFEM(greedy, answer): + count = 0 + correct = 0 + text_answers = [] + for idx, (g, ex) in enumerate(zip(greedy, answer)): + count += 1 + + text_answers.append([ex['answer'].lower()]) + try: + gt = ex['sql'] + conds = gt['conds'] + lower_conds = [] + for c in conds: + lc = c + lc[2] = str(lc[2]).lower().replace(' ', '') + lower_conds.append(lc) + gt['conds'] = lower_conds + # print("gt_answer:",ex['answer'].lower()) + lf = to_lf(g, ex['table']) + # print(lf,"lf") + # print(gt,"gt") + correct += lf == gt + + except Exception as e: + continue + return (correct / count) * 100, 
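For orientation, the per-example structure that `computeLFEM` and `to_lf` consume from `answer` (field names inferred from the attribute accesses above, values hypothetical) looks roughly like this:

```python
# Sketch of one entry of `answer` as computeLFEM reads it; the gold logical
# form is compared against to_lf(prediction, table).
example = {
    "answer": "3",                                      # gold answer text
    "sql": {"sel": 2,                                   # selected column index
            "agg": 3,                                   # index into AGG_OPS ('COUNT')
            "conds": [[1, 0, "lakers"]]},               # [column, operator, value] triples
    "table": {"header": ["Player", "Team", "Points"]},  # used to ground column mentions
}
```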
text_answers + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def f1_score(prediction, ground_truth): + prediction_tokens = prediction.split() + ground_truth_tokens = ground_truth.split() + common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match(prediction, ground_truth): + return prediction == ground_truth + + +def computeF1(outputs, targets): + return sum([metric_max_over_ground_truths(f1_score, o, t) for o, t in zip(outputs, targets)]) / len(outputs) * 100 + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for idx, ground_truth in enumerate(ground_truths): + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def computeEM(outputs, targets): + outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(outputs, targets)] + return sum(outs) / len(outputs) * 100 + + +def fix_buggy_characters(str): + return re.sub("[{}^\\\\`\u2047<]", " ", str) +def replace_punctuation(str): + return str.replace("\"", "").replace("'", "") +def score_string_similarity(str1, str2): + if str1 == str2: + return 3.0 # Better than perfect token match + str1 = fix_buggy_characters(replace_punctuation(str1)) + str2 = fix_buggy_characters(replace_punctuation(str2)) + if str1 == str2: + return 2.0 + if " " in str1 or " " in str2: + str1_split = str1.split(" ") + str2_split = str2.split(" ") + overlap = list(set(str1_split) & set(str2_split)) + return len(overlap) / max(len(str1_split), len(str2_split)) + else: + if str1 == str2: + return 1.0 + else: + return 0.0 + +def computeEMtri(outputs, targets): + options = ["entailment","neutral","contradiction"] + preds = [] + for o in outputs: + scores = [score_string_similarity(opt, o) for opt in options] + max_idx = np.argmax(scores) + preds.append(options[max_idx]) + print(preds) + print(targets) + outs = [metric_max_over_ground_truths(exact_match, o, t) for o, t in zip(preds, targets)] + return sum(outs) / len(outputs) * 100 + +def score(answer, gold): + if len(gold) > 0: + gold = set.union(*[simplify(g) for g in gold]) + answer = simplify(answer) + tp, tn, sys_pos, real_pos = 0, 0, 0, 0 + if answer == gold: + if not ('unanswerable' in gold and len(gold) == 1): + tp += 1 + else: + tn += 1 + if not ('unanswerable' in answer and len(answer) == 1): + sys_pos += 1 + if not ('unanswerable' in gold and len(gold) == 1): + real_pos += 1 + return np.array([tp, tn, sys_pos, real_pos]) + + +def simplify(answer): + return set(''.join(c for c in t if c not in string.punctuation) for t in answer.strip().lower().split()) - {'the', 'a', 'an', 'and', ''} + + +def computeCF1(greedy, answer): + scores = np.zeros(4) + for g, a in zip(greedy, answer): + scores += score(g, a) + tp, tn, sys_pos, real_pos = scores.tolist() + total = len(answer) + if tp == 0: + p = r = f = 0.0 + else: + p = tp / float(sys_pos) + r = tp / float(real_pos) + f = 2 * p * r / (p + r) + return f * 100, p * 100, r * 100 + + +def computeDialogue(greedy, 
answer): + examples = [] + for idx, (g, a) in enumerate(zip(greedy, answer)): + examples.append((a[0], g, a[1], idx)) + #examples.sort() + turn_request_positives = 0 + turn_goal_positives = 0 + joint_goal_positives = 0 + ldt = None + for ex in examples: + if ldt is None or ldt.split('_')[:-1] != ex[0].split('_')[:-1]: + state, answer_state = {}, {} + ldt = ex[0] + delta_state = to_delta_state(ex[1]) + answer_delta_state = to_delta_state(ex[2]) + state = update_state(state, delta_state['inform']) + answer_state = update_state(answer_state, answer_delta_state['inform']) + if dict_cmp(state, answer_state): + joint_goal_positives += 1 + if delta_state['request'] == answer_delta_state['request']: + turn_request_positives += 1 + if dict_cmp(delta_state['inform'], answer_delta_state['inform']): + turn_goal_positives += 1 + + joint_goal_em = joint_goal_positives / len(examples) * 100 + turn_request_em = turn_request_positives / len(examples) * 100 + turn_goal_em = turn_goal_positives / len(examples) * 100 + answer = [(x[-1], x[-2]) for x in examples] + #answer.sort() + answer = [[x[1]] for x in answer] + return joint_goal_em, turn_request_em, turn_goal_em, answer + + +def compute_metrics(data, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False,tri=False): + if rouge: + metric_func = load_metric("metric/rouge_local/rouge_metric.py") + metrics = metric_func.compute(predictions=data.predictions, references=data.label_ids) + metric_keys = ["rougeL"] + metric_values = metrics["rougeL"] + metric_dict = collections.OrderedDict(list(zip(metric_keys, [metric_values]))) + return metric_dict + + greedy = data.predictions + answer = data.label_ids + greedy = [_["prediction_text"] for _ in greedy] + # greedy = [_["answers"]["text"][0].lower() for _ in answer] + answer = [_["answers"]["text"] for _ in answer] + + + if dialogue: + + addition_answers = open("./oursdeca/"+"woz.en_answers.json",'r') + answer = json.load(addition_answers) + if logical_form: + addition_answers = open("./oursdeca/"+"wikisql_answers.json",'r') + answer = json.load(addition_answers) + + + + metric_keys = [] + metric_values = [] + if logical_form: + lfem, answer = computeLFEM(greedy, answer) + metric_keys += ['lfem'] + metric_values += [lfem] + if tri: + em = computeEMtri(greedy,answer) + else: + em = computeEM(greedy, answer) + print(greedy[:20]) + print(answer[:20]) + metric_keys.append('em') + metric_values.append(em) + norm_greedy = [normalize_text(g) for g in greedy] + norm_answer = [[normalize_text(a) for a in ans] for ans in answer] + + nf1 = computeF1(norm_greedy, norm_answer) + nem = computeEM(norm_greedy, norm_answer) + metric_keys.extend(['nf1', 'nem']) + metric_values.extend([nf1, nem]) + if rouge: + rouge = computeROUGE(greedy, answer) + metric_keys += ['rouge1', 'rouge2', 'rougeL', 'avg_rouge'] + avg_rouge = (rouge['rouge_1_f_score'] + rouge['rouge_2_f_score'] + rouge['rouge_l_f_score']) / 3 + metric_values += [rouge['rouge_1_f_score'], rouge['rouge_2_f_score'], rouge['rouge_l_f_score'], avg_rouge] + if corpus_f1: + corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer) + metric_keys += ['corpus_f1', 'precision', 'recall'] + metric_values += [corpus_f1, precision, recall] + if dialogue: + joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer) + avg_dialogue = (joint_goal_em + request_em) / 2 + metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue'] + metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue] + 
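Putting the pieces together, `compute_metrics` consumes an `EvalPrediction`-style object whose `predictions` carry `"prediction_text"` dicts and whose `label_ids` carry SQuAD-style `"answers"` dicts. A minimal usage sketch (assuming the script is run from the `diana/` directory so `metrics` is importable; `SimpleNamespace` stands in for the trainer's prediction object):

```python
from types import SimpleNamespace
from metrics import compute_metrics

data = SimpleNamespace(
    predictions=[{"prediction_text": "the eiffel tower"}],
    label_ids=[{"answers": {"text": ["Eiffel Tower"]}}],
)
print(compute_metrics(data))
# OrderedDict([('em', 0.0), ('nf1', 100.0), ('nem', 100.0)])
# raw EM fails on the article/casing; normalised EM and F1 are perfect
```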
metric_dict = collections.OrderedDict(list(zip(metric_keys, metric_values))) + return metric_dict \ No newline at end of file diff --git a/diana/models/metadeca.py b/diana/models/metadeca.py new file mode 100755 index 0000000..c0fe796 --- /dev/null +++ b/diana/models/metadeca.py @@ -0,0 +1,2406 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. """ + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
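The bucketing scheme described in `_relative_position_bucket` above is easiest to see on a few concrete offsets: in the bidirectional (encoder) case the 32 buckets split into 16 for negative and 16 for positive offsets, small offsets get exact buckets, and larger ones are compressed logarithmically until they saturate at the last bucket of their half. A quick numeric check with the default `num_buckets=32`, `max_distance=128` (assumes this module is imported so `T5Attention` is in scope):

```python
import torch

rel = torch.tensor([[-1, 0, 1, 3, 20, 100]])  # memory_position - query_position
buckets = T5Attention._relative_position_bucket(rel, bidirectional=True)
# buckets -> tensor([[ 1,  0, 17, 19, 26, 31]])
# negative offsets land in buckets 0..15, positive ones in 16..31
```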
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=20 + self.start_step = 0 + self.format_num = 3 + self.task_num = 10 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:8,1:6,2:7} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + def count_clear(self): + self.total = 0 + self.right = 0 + self.wrong = 0 + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims 
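A toy illustration (not part of the patch) of what `_shift_right` above produces: labels are shifted one step to the right, position 0 becomes `decoder_start_token_id` (the pad id for T5), and any remaining `-100` ignore-index values are mapped back to the pad id:

```python
import torch

labels = torch.tensor([[6, 7, -100, -100]])   # toy label ids with ignore-index padding
pad_id = start_id = 0                         # T5 uses pad as decoder_start_token_id

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = start_id
shifted.masked_fill_(shifted == -100, pad_id)
# shifted -> tensor([[0, 6, 7, 0]])
```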
= torch.matmul(querys,keys.t()) + top1 = torch.topk(domain_sims,1).indices + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + selected = top1.detach().cpu().numpy().tolist() + if ids is not None: + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + break + + assert len(top1.size())==2 + self_sims = torch.matmul(keys,keys.t()) + + self_sims = self_sims.detach().cpu().numpy() + domain_sims = domain_sims.detach().cpu().numpy() + all_selected = [] + for line in range(domain_sims.shape[0]): + already = selected[line] + domain_sims_line = domain_sims[line] + for idx in range(4): + sims_from = self_sims[already] + max_sims = np.max(sims_from,axis=0) + domain_aggregate = domain_sims_line*0.6-max_sims*0.4 + ids = np.argsort(-domain_aggregate).tolist() + for item in ids: + if (not item in already): + already.append(item) + break + all_selected.append(already) + # print(all_selected) + all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device) + return all_selected + + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + self.block2[layer] = self.block2[leyer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.block2[i] = self.block2[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + self.final_layer_norm2 = self.final_layer_norm2.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def get_query_vector( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None + ): + assert (self.name=="encoder") +################## + + # Model parallel + if 
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-300)//10 + task_id = prompt_ids[:,20]//self.prompt_num + + if train is None: + allkey_loss = None + disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + + prompt_ids[:,self.prompt_num*2:self.prompt_num*2+5*self.domain_size]+=disp_domain + delta_task = guess_task_key_indexs + + centro = self.task_keys/torch.norm(self.task_keys, p=2, dim=1, keepdim=True) + logits = cosine_metric(query_vector,centro) + probs, preds = F.softmax(logits.detach(), dim = 1).max(dim = 1) + cos_distance = 1-(query_vector*centro[preds]).sum(dim=1) + + preds_ = copy.deepcopy(preds.detach()) + try: + preds[cos_distance >= self.delta[preds]] = self.format2general[format_id] + except: + pass + guess_task_key_indexs = preds.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + # print(preds) + delta_task = 
delta_task.unsqueeze(1).repeat(repeats=(1,20))*self.prompt_num + prompt_ids[:,self.prompt_num:self.prompt_num*2]+=delta_task + + else: + self.train_count+=3e-4 + _r = random() + if _r>0.95: + label_id = copy.deepcopy(task_id) + label_id = label_id.squeeze() + if (len(label_id.size())==0): + label_id = label_id.unsqueeze(0) + allkey_loss,self.delta = self.criterion(query_vector,self.task_keys,label_id) + + else: + if _r= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2) + + inputs_embeds = inputs_embeds + prompt_emb + elif(self.name=="decoder"): + inputs_embeds = self.embed_tokens(input_ids) + + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = 
encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + 
all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss is not None: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). 
+ + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. 
+ + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
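# The decoder blocks are spread with the same device_map as the encoder, and the lm_head is then moved to the decoder's first device. +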
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
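+ # When past_key_values are supplied, only the most recent decoder position is new, so the decoder + # inputs are trimmed to the last token below; this is why cached decoding is reserved for generation.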
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + lm_logits = self.lm_head(sequence_output) + # print("get_lm_logits") + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + + + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # task similarity loss + # print(loss,'\n') + # print("get_sim_loss") + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + if sim_loss is not None: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ),sim_loss + else: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not 
None: + input_ids = input_ids[:, -1:] + + if "task_prompt_recored" in kwargs: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "task_prompt_recored": kwargs["task_prompt_recored"] + + } + else: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5EncoderModel + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5EncoderModel.from_pretrained('t5-small') + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/diana/models/metadecanometa.py b/diana/models/metadecanometa.py new file mode 100755 index 0000000..526892a --- /dev/null +++ b/diana/models/metadecanometa.py @@ -0,0 +1,2386 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. 
""" + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
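+ Equivalently, this is an RMS norm: each hidden vector is scaled by the reciprocal of its root mean square (with ``eps`` added to the variance for stability) and then multiplied by a learned per-dimension weight.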
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
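+ For instance, with the defaults used here (bidirectional=True, num_buckets=32, max_distance=128) each direction gets 16 buckets: distances 0-7 keep their own bucket, distances 8-127 share logarithmically sized buckets, and anything at or beyond 128 is clamped into the last bucket.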
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
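The `_shift_right` logic here prepends the decoder start token, drops the last label, and maps ignored `-100` labels back to the pad id. A minimal standalone sketch of the plain-tensor branch (toy values; assumes the usual T5 defaults `pad_token_id = 0`, `decoder_start_token_id = 0`):

import torch

def shift_right_demo(labels, decoder_start_token_id=0, pad_token_id=0):
    # Mirrors the plain-tensor branch: prepend the decoder start token,
    # drop the last position, and replace -100 (ignored) labels with pad.
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[37, 423, 5, -100]])
print(shift_right_demo(labels))  # tensor([[  0,  37, 423,   5]])
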
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=20 + self.start_step = 0 + self.format_num = 3 + self.task_num = 10 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:8,1:6,2:7} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + def count_clear(self): + self.total = 0 + self.right = 0 + self.wrong = 0 + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims 
= torch.matmul(querys,keys.t())
+        top1 = torch.topk(domain_sims,1).indices
+        _,indices = torch.sort(domain_sims,dim=1,descending=True)
+        indices = indices.detach().cpu().numpy().tolist()
+        selected = top1.detach().cpu().numpy().tolist()
+        if ids is not None:
+            self.start_step-=1
+
+            if self.start_step<=0:
+
+                device_index = ids.device.index
+                ids_ = ids.detach().cpu().numpy().tolist()
+                tmp = self.buffer_ids
+                for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)):
+                    if idd in tmp:
+                        continue
+                    for choice in sub_indices:
+                        if domain_sims[idx][choice].item()>self.buffers[choice]:
+                            self.buffers[choice]=domain_sims[idx][choice].item()
+                            self.buffer_ids[choice]=idd
+                            self.vectors[choice]=querys[idx].detach().cpu().numpy()
+                            break
+
+        assert len(top1.size())==2
+        self_sims = torch.matmul(keys,keys.t())
+
+        self_sims = self_sims.detach().cpu().numpy()
+        domain_sims = domain_sims.detach().cpu().numpy()
+        all_selected = []
+        for line in range(domain_sims.shape[0]):
+            already = selected[line]
+            domain_sims_line = domain_sims[line]
+            for idx in range(4):
+                sims_from = self_sims[already]
+                max_sims = np.max(sims_from,axis=0)
+                domain_aggregate = domain_sims_line*0.6-max_sims*0.4
+                ids = np.argsort(-domain_aggregate).tolist()
+                for item in ids:
+                    if (not item in already):
+                        already.append(item)
+                        break
+            all_selected.append(already)
+        # print(all_selected)
+        all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device)
+        return all_selected
+
+
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        # Check validity of device_map
+        self.device_map = (
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+        )
+        assert_device_map(self.device_map, len(self.block))
+        self.model_parallel = True
+        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
+        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        # Load onto devices
+        for k, v in self.device_map.items():
+            for layer in v:
+                cuda_device = "cuda:" + str(k)
+                self.block[layer] = self.block[layer].to(cuda_device)
+                self.block2[layer] = self.block2[layer].to(cuda_device)
+
+        # Set embed_tokens to first layer
+        self.embed_tokens = self.embed_tokens.to(self.first_device)
+        # Set final layer norm to last device
+        self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+        self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device)
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.model_parallel = False
+        self.device_map = None
+        self.first_device = "cpu"
+        self.last_device = "cpu"
+        for i in range(len(self.block)):
+            self.block[i] = self.block[i].to("cpu")
+            self.block2[i] = self.block2[i].to("cpu")
+        self.embed_tokens = self.embed_tokens.to("cpu")
+        self.final_layer_norm = self.final_layer_norm.to("cpu")
+        self.final_layer_norm2 = self.final_layer_norm2.to("cpu")
+        torch.cuda.empty_cache()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embed_tokens = new_embeddings
+
+    def get_query_vector(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        inputs_embeds=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None
+    ):
+        assert (self.name=="encoder")
+##################
+
+        # Model parallel
+        if
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-300)//10 + task_id = prompt_ids[:,20]//self.prompt_num + + if train is None: + allkey_loss = None + + delta_task = guess_task_key_indexs + + centro = self.task_keys/torch.norm(self.task_keys, p=2, dim=1, keepdim=True) + logits = cosine_metric(query_vector,centro) + probs, preds = F.softmax(logits.detach(), dim = 1).max(dim = 1) + cos_distance = 1-(query_vector*centro[preds]).sum(dim=1) + + preds_ = copy.deepcopy(preds.detach()) + try: + preds[cos_distance >= self.delta[preds]] = self.format2general[format_id] + except: + pass + guess_task_key_indexs = preds.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + # print(preds) + delta_task = delta_task.unsqueeze(1).repeat(repeats=(1,20))*self.prompt_num + prompt_ids[:,self.prompt_num:self.prompt_num*2]+=delta_task + + else: + self.train_count+=3e-4 + _r = random() + if _r>0.95: + 
label_id = copy.deepcopy(task_id) + label_id = label_id.squeeze() + if (len(label_id.size())==0): + label_id = label_id.unsqueeze(0) + allkey_loss,self.delta = self.criterion(query_vector,self.task_keys,label_id) + + else: + if _r= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2) + + inputs_embeds = inputs_embeds + prompt_emb + elif(self.name=="decoder"): + inputs_embeds = self.embed_tokens(input_ids) + + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if 
encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss is not None: + return BaseModelOutputWithPastAndCrossAttentions( + 
last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. 
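In the encoder branch of `T5Stack.forward` above, prompts are chosen by cosine similarity between a pooled query vector and the learned `task_keys`/`domain_keys`, and the selected prompt embeddings are added to the token embeddings. A simplified, self-contained sketch of that L2P-style key matching (toy sizes and the hypothetical helper name `select_task_prompts`; the real code additionally handles domain keys, negative prompt-id placeholders, and the open-set fallback):

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, task_num, prompt_num = 16, 10, 20
task_keys = nn.Parameter(torch.randn(task_num, d_model))
prompt_embeddings = nn.Embedding(task_num * prompt_num, d_model)

def select_task_prompts(query_vector):
    # L2P-style lookup: normalise the pooled query and the task keys,
    # pick the most similar key, and return that task's prompt slice.
    q = F.normalize(query_vector, p=2, dim=-1)
    k = F.normalize(task_keys, p=2, dim=-1)
    sims = q @ k.t()                               # (batch, task_num)
    task_idx = sims.argmax(dim=-1)                 # (batch,)
    offsets = task_idx.unsqueeze(1) * prompt_num + torch.arange(prompt_num)
    return prompt_embeddings(offsets)              # (batch, prompt_num, d_model)

query = torch.randn(4, d_model)                    # e.g. mean-pooled hidden states
print(select_task_prompts(query).shape)            # torch.Size([4, 20, 16])
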
+ decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
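As the `use_cache`/`past_key_values` notes above describe, cached key/value states let generation feed only the newest decoder token at each step. A rough illustration with the stock upstream `t5-small` checkpoint (the patched encoder in this file additionally expects prompt-id placeholders, so this sketch uses plain `transformers`):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

input_ids = tokenizer("translate English to German: Hello", return_tensors="pt").input_ids
encoder_outputs = model.get_encoder()(input_ids=input_ids)

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None
with torch.no_grad():
    for _ in range(5):  # greedy loop, shortened for clarity
        out = model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids if past_key_values is None else decoder_input_ids[:, -1:],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = out.past_key_values       # reuse cached keys/values
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
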
+ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
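The `copy_encoder`, `freeze_all`, and `unfreeze_all` helpers defined above imply a phased training loop: snapshot the encoder into the frozen `block2` query encoder, then switch between prompt-only tuning and full fine-tuning. A hedged sketch of how a driver script might call them (the ordering and function names are assumptions; the actual downstream scripts are not shown in this hunk):

def begin_new_task(model):
    # Snapshot the current encoder into the frozen block2 "query encoder"
    # so query vectors stay stable while prompts/keys are tuned.
    model.copy_encoder()

def prompt_only_phase(model):
    # Only prompt embeddings and task keys receive gradients.
    model.freeze_all()
    return [p for p in model.parameters() if p.requires_grad]

def full_finetune_phase(model):
    # Main encoder/decoder train again; the block2 copy stays frozen.
    model.unfreeze_all()
    return [p for p in model.parameters() if p.requires_grad]
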
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
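Note that this `forward` unpacks an auxiliary `sim_loss` from the encoder call above and, further down, may return it alongside the `Seq2SeqLMOutput`, so callers can receive either a single output or a tuple. A small defensive unpacking helper a caller might use (sketch; `unpack_model_output` is illustrative, not part of the patch):

def unpack_model_output(model_output):
    # The patched forward may return (Seq2SeqLMOutput, aux_key_loss) when the
    # encoder computed a key-matching loss, otherwise a plain Seq2SeqLMOutput.
    if isinstance(model_output, tuple):
        seq2seq_out, aux_loss = model_output
    else:
        seq2seq_out, aux_loss = model_output, None
    return seq2seq_out, aux_loss

# e.g. in a caller (hypothetical names):
#   out, aux_key_loss = unpack_model_output(model(**model_inputs))
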
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + lm_logits = self.lm_head(sequence_output) + # print("get_lm_logits") + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + + + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # task similarity loss + # print(loss,'\n') + # print("get_sim_loss") + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + if sim_loss is not None: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ),sim_loss + else: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not 
None: + input_ids = input_ids[:, -1:] + + if "task_prompt_recored" in kwargs: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "task_prompt_recored": kwargs["task_prompt_recored"] + + } + else: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5EncoderModel + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5EncoderModel.from_pretrained('t5-small') + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/diana/models/metadecanotask.py b/diana/models/metadecanotask.py new file mode 100755 index 0000000..af3ae1b --- /dev/null +++ b/diana/models/metadecanotask.py @@ -0,0 +1,2349 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. 
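+ This variant keeps the stock T5 layers but adds a second encoder pass ("block2")
+ that produces an instance-level query vector, together with pools of learned
+ task- and domain-level prompt keys that are matched against that query (see
+ ``T5Stack`` below).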
""" + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=20 + self.start_step = 0 + self.format_num = 3 + self.task_num = 10 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:8,1:6,2:7} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + def count_clear(self): + self.total = 0 + self.right = 0 + self.wrong = 0 + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims 
= torch.matmul(querys,keys.t())
+ top1 = torch.topk(domain_sims,1).indices
+ _,indices = torch.sort(domain_sims,dim=1,descending=True)
+ indices = indices.detach().cpu().numpy().tolist()
+ selected = top1.detach().cpu().numpy().tolist()
+ if ids is not None:
+ self.start_step-=1
+
+ if self.start_step<=0:
+
+ device_index = ids.device.index
+ ids_ = ids.detach().cpu().numpy().tolist()
+ tmp = self.buffer_ids
+ for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)):
+ if idd in tmp:
+ continue
+ for choice in sub_indices:
+ if domain_sims[idx][choice].item()>self.buffers[choice]:
+ self.buffers[choice]=domain_sims[idx][choice].item()
+ self.buffer_ids[choice]=idd
+ self.vectors[choice]=querys[idx].detach().cpu().numpy()
+ break
+
+ assert len(top1.size())==2
+ self_sims = torch.matmul(keys,keys.t())
+
+ self_sims = self_sims.detach().cpu().numpy()
+ domain_sims = domain_sims.detach().cpu().numpy()
+ all_selected = []
+ # Greedy MMR-style re-ranking: weigh query relevance (0.6) against similarity to already-selected keys (0.4)
+ for line in range(domain_sims.shape[0]):
+ already = selected[line]
+ domain_sims_line = domain_sims[line]
+ for idx in range(4):
+ sims_from = self_sims[already]
+ max_sims = np.max(sims_from,axis=0)
+ domain_aggregate = domain_sims_line*0.6-max_sims*0.4
+ ids = np.argsort(-domain_aggregate).tolist()
+ for item in ids:
+ if (not item in already):
+ already.append(item)
+ break
+ all_selected.append(already)
+ # print(all_selected)
+ all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device)
+ return all_selected
+
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ # Check validity of device_map
+ self.device_map = (
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
+ assert_device_map(self.device_map, len(self.block))
+ self.model_parallel = True
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
+ # Load onto devices
+ for k, v in self.device_map.items():
+ for layer in v:
+ cuda_device = "cuda:" + str(k)
+ self.block[layer] = self.block[layer].to(cuda_device)
+ self.block2[layer] = self.block2[layer].to(cuda_device)
+
+ # Set embed_tokens to first layer
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
+ # Set final layer norm to last device
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+ self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device)
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ self.model_parallel = False
+ self.device_map = None
+ self.first_device = "cpu"
+ self.last_device = "cpu"
+ for i in range(len(self.block)):
+ self.block[i] = self.block[i].to("cpu")
+ self.block2[i] = self.block2[i].to("cpu")
+ self.embed_tokens = self.embed_tokens.to("cpu")
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
+ self.final_layer_norm2 = self.final_layer_norm2.to("cpu")
+ torch.cuda.empty_cache()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ def get_query_vector(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None
+ ):
+ assert (self.name=="encoder")
+##################
+
+ # Model parallel
+ if
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
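+ # Gradient checkpointing re-runs each block during the backward pass, so the
+ # key/value cache cannot be kept; fall back to use_cache=False.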
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-300)//10 + + if train is None: + allkey_loss = None + disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + + prompt_ids[:,self.prompt_num*2:self.prompt_num*2+5*self.domain_size]+=disp_domain + + + else: + allkey_loss = 0 + + domain_similarity = torch.sum(torch.gather(sims_domain,1,guess_domain_key_indexs))/guess_domain_key_indexs.size(0)/5 + domain_cosine = 1-domain_similarity + + domain_self_similarity = torch.matmul(dk,dk.t()) + _I = torch.eye(domain_self_similarity.size(0)).to(sims_domain.device) + margined = domain_self_similarity-0.85-_I + zeros_margined = torch.zeros(domain_self_similarity.size(0),domain_self_similarity.size(1)).to(sims_domain.device) + margin_loss = torch.max(zeros_margined,margined) + + diversity_reg = torch.sum(margin_loss)/domain_self_similarity.size(0) + + + + allkey_loss+=domain_cosine + allkey_loss+=diversity_reg + + 
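+ # Training-time key losses accumulated above: pull the selected domain keys
+ # toward the query (1 - mean cosine similarity) and push distinct keys apart
+ # whenever their pairwise similarity exceeds the 0.85 margin (diversity term).
+ # Next, the reserved prompt-id slots are shifted so each example indexes the
+ # domain_size prompt embeddings attached to its selected domain keys.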
disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + prompt_ids[:,self.prompt_num*2:self.prompt_num*2+5*self.domain_size]+=disp_domain + + prompt_emb = self.prompt_embeddings(prompt_ids * (prompt_ids >= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2) + + inputs_embeds = inputs_embeds + prompt_emb + elif(self.name=="decoder"): + inputs_embeds = self.embed_tokens(input_ids) + + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = 
encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss is not None: + return 
BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. 
+ decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + lm_logits = self.lm_head(sequence_output) + # print("get_lm_logits") + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + + + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # task similarity loss + # print(loss,'\n') + # print("get_sim_loss") + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + if sim_loss is not None: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ),sim_loss + else: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not 
None: + input_ids = input_ids[:, -1:] + + if "task_prompt_recored" in kwargs: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "task_prompt_recored": kwargs["task_prompt_recored"] + + } + else: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5EncoderModel + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5EncoderModel.from_pretrained('t5-small') + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/diana/models/metat5.py b/diana/models/metat5.py new file mode 100755 index 0000000..e361448 --- /dev/null +++ b/diana/models/metat5.py @@ -0,0 +1,2409 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. 
""" + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=40 + self.start_step = 0 + self.format_num = 3 + self.task_num = 11 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:4,1:5,2:10} + self.format2tasks={0:[0,3,4],1:[1,6,5],2:[2,7,8,9,10]} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims = 
torch.matmul(querys,keys.t())
+ top1 = torch.topk(domain_sims,1).indices
+ _,indices = torch.sort(domain_sims,dim=1,descending=True)
+ indices = indices.detach().cpu().numpy().tolist()
+ selected = top1.detach().cpu().numpy().tolist()
+ if ids is not None:
+ self.start_step-=1
+
+ if self.start_step<=0:
+
+ device_index = ids.device.index
+ ids_ = ids.detach().cpu().numpy().tolist()
+ tmp = self.buffer_ids
+ for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)):
+ if idd in tmp:
+ continue
+ for choice in sub_indices:
+ if domain_sims[idx][choice].item()>self.buffers[choice]:
+ self.buffers[choice]=domain_sims[idx][choice].item()
+ self.buffer_ids[choice]=idd
+ self.vectors[choice]=querys[idx].detach().cpu().numpy()
+ break
+
+ assert len(top1.size())==2
+ self_sims = torch.matmul(keys,keys.t())
+
+ self_sims = self_sims.detach().cpu().numpy()
+ domain_sims = domain_sims.detach().cpu().numpy()
+ all_selected = []
+ for line in range(domain_sims.shape[0]):
+ already = selected[line]
+ domain_sims_line = domain_sims[line]
+ for idx in range(4):
+ sims_from = self_sims[already]
+ max_sims = np.max(sims_from,axis=0)
+ domain_aggregate = domain_sims_line*0.6-max_sims*0.4
+ ids = np.argsort(-domain_aggregate).tolist()
+ for item in ids:
+ if (not item in already):
+ already.append(item)
+ break
+ all_selected.append(already)
+ # print(all_selected)
+ all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device)
+ return all_selected
+
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ # Check validity of device_map
+ self.device_map = (
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
+ assert_device_map(self.device_map, len(self.block))
+ self.model_parallel = True
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
+ # Load onto devices
+ for k, v in self.device_map.items():
+ for layer in v:
+ cuda_device = "cuda:" + str(k)
+ self.block[layer] = self.block[layer].to(cuda_device)
+ self.block2[layer] = self.block2[layer].to(cuda_device)
+
+ # Set embed_tokens to first layer
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
+ # Set final layer norm to last device
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+ self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device)
+
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ self.model_parallel = False
+ self.device_map = None
+ self.first_device = "cpu"
+ self.last_device = "cpu"
+ for i in range(len(self.block)):
+ self.block[i] = self.block[i].to("cpu")
+ self.block2[i] = self.block2[i].to("cpu")
+ self.embed_tokens = self.embed_tokens.to("cpu")
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
+ self.final_layer_norm2 = self.final_layer_norm2.to("cpu")
+ torch.cuda.empty_cache()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ def get_query_vector(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None
+ ):
+ assert (self.name=="encoder")
+##################
+
+ # Model parallel
+ if 
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
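+ # Gradient checkpointing recomputes activations during the backward pass, so cached
+ # key/value states cannot be reused; warn and fall back to use_cache=False below.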
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-500)//self.prompt_num + task_id = prompt_ids[:,50]//self.prompt_num + + if train is None: + allkey_loss = None + disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + + prompt_ids[:,10+self.prompt_num*2:10+self.prompt_num*2+5*self.domain_size]+=disp_domain + delta_task = guess_task_key_indexs + + centro = self.task_keys/torch.norm(self.task_keys, p=2, dim=1, keepdim=True) + logits = cosine_metric(query_vector,centro) + for _item in range(self.task_num): + if not (_item in self.format2tasks[format_id]): + logits[:,_item]=-1.0 + probs, preds = F.softmax(logits.detach(), dim = 1).max(dim = 1) + cos_distance = 1-(query_vector*centro[preds]).sum(dim=1) + + preds_ = copy.deepcopy(preds.detach()) + try: + preds[cos_distance >= self.delta[preds]] = self.format2general[format_id] + except: + pass + guess_task_key_indexs = preds.squeeze() + if 
(len(guess_task_key_indexs.size())==0):
+ guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0)
+
+ delta_task = delta_task.unsqueeze(1).repeat(repeats=(1,40))*self.prompt_num
+ prompt_ids[:,10+self.prompt_num:10+self.prompt_num*2]+=delta_task
+
+ else:
+ self.train_count+=3e-4
+ _r = random()
+ if _r>0.95:
+ label_id = copy.deepcopy(task_id)
+ label_id = label_id.squeeze()
+ if (len(label_id.size())==0):
+ label_id = label_id.unsqueeze(0)
+ allkey_loss,self.delta = self.criterion(query_vector,self.task_keys,label_id)
+
+ else:
+ if _r<3e-1:
+ allkey_loss=0
+
+
+ domain_similarity = torch.sum(torch.gather(sims_domain,1,guess_domain_key_indexs))/guess_domain_key_indexs.size(0)/5
+ domain_cosine = 1-domain_similarity
+
+ domain_self_similarity = torch.matmul(dk,dk.t())
+ _I = torch.eye(domain_self_similarity.size(0)).to(sims_domain.device)
+ margined = domain_self_similarity-0.85-_I
+ zeros_margined = torch.zeros(domain_self_similarity.size(0),domain_self_similarity.size(1)).to(sims_domain.device)
+ margin_loss = torch.max(zeros_margined,margined)
+
+ diversity_reg = torch.sum(margin_loss)/domain_self_similarity.size(0)
+
+
+
+ allkey_loss+=domain_cosine
+ allkey_loss+=diversity_reg
+
+ disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size
+ prompt_ids[:,10+self.prompt_num*2:10+self.prompt_num*2+5*self.domain_size]+=disp_domain
+
+ prompt_emb = self.prompt_embeddings(prompt_ids * (prompt_ids >= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2)
+
+ inputs_embeds = inputs_embeds + prompt_emb
+ elif(self.name=="decoder"):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ if use_cache is True:
+ assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder"
+
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ encoder_attention_mask = torch.ones(
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+ )
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads. 
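+ # get_extended_attention_mask converts the [batch_size, seq_length] padding mask into an
+ # additive mask that broadcasts over heads and query positions, with large negative
+ # values at masked positions so they vanish after the softmax.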
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss is not None: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, 
Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. 
Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
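+ # The decoder stack reuses the same device_map, so encoder and decoder blocks are
+ # spread over the same set of GPUs.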
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
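+ # With cached past_key_values (incremental decoding), only the most recent decoder
+ # position needs to be fed; earlier positions are already covered by the cache.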
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + lm_logits = self.lm_head(sequence_output) + # print("get_lm_logits") + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + + + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # task similarity loss + # print(loss,'\n') + # print("get_sim_loss") + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + if sim_loss is not None: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ),sim_loss + else: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not 
None: + input_ids = input_ids[:, -1:] + + if "task_prompt_recored" in kwargs: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "task_prompt_recored": kwargs["task_prompt_recored"] + + } + else: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5EncoderModel + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5EncoderModel.from_pretrained('t5-small') + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/diana/models/metat5nometa.py b/diana/models/metat5nometa.py new file mode 100755 index 0000000..73c5f6d --- /dev/null +++ b/diana/models/metat5nometa.py @@ -0,0 +1,2390 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. 
""" + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
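+ Concretely, this is RMS-style normalization: activations are scaled by 1/sqrt(mean(x**2, dim=-1) + eps) and then multiplied by a learned weight (see ``forward`` below).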
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=40 + self.start_step = 0 + self.format_num = 3 + self.task_num = 11 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:4,1:5,2:10} + self.format2tasks={0:[0,3,4],1:[1,6,5],2:[2,7,8,9,10]} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims = 
torch.matmul(querys,keys.t()) + top1 = torch.topk(domain_sims,1).indices + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + selected = top1.detach().cpu().numpy().tolist() + if ids is not None: + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + break + + assert len(top1.size())==2 + self_sims = torch.matmul(keys,keys.t()) + + self_sims = self_sims.detach().cpu().numpy() + domain_sims = domain_sims.detach().cpu().numpy() + all_selected = [] + for line in range(domain_sims.shape[0]): + already = selected[line] + domain_sims_line = domain_sims[line] + for idx in range(4): + sims_from = self_sims[already] + max_sims = np.max(sims_from,axis=0) + domain_aggregate = domain_sims_line*0.6-max_sims*0.4 + ids = np.argsort(-domain_aggregate).tolist() + for item in ids: + if (not item in already): + already.append(item) + break + all_selected.append(already) + # print(all_selected) + all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device) + return all_selected + + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + self.block2[layer] = self.block2[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.block2[i] = self.block2[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + self.final_layer_norm2 = self.final_layer_norm2.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def get_query_vector( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None + ): + assert (self.name=="encoder") +################## + + # Model parallel + if
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-500)//self.prompt_num + task_id = prompt_ids[:,50]//self.prompt_num + + if train is None: + allkey_loss = None + + delta_task = guess_task_key_indexs + + centro = self.task_keys/torch.norm(self.task_keys, p=2, dim=1, keepdim=True) + logits = cosine_metric(query_vector,centro) + for _item in range(self.task_num): + if not (_item in self.format2tasks[format_id]): + logits[:,_item]=-1.0 + probs, preds = F.softmax(logits.detach(), dim = 1).max(dim = 1) + cos_distance = 1-(query_vector*centro[preds]).sum(dim=1) + + preds_ = copy.deepcopy(preds.detach()) + try: + preds[cos_distance >= self.delta[preds]] = self.format2general[format_id] + except: + pass + guess_task_key_indexs = preds.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + delta_task = delta_task.unsqueeze(1).repeat(repeats=(1,40))*self.prompt_num + 
prompt_ids[:,self.prompt_num:self.prompt_num*2]+=delta_task + + else: + self.train_count+=3e-4 + _r = random() + if _r>0.95: + label_id = copy.deepcopy(task_id) + label_id = label_id.squeeze() + if (len(label_id.size())==0): + label_id = label_id.unsqueeze(0) + allkey_loss,self.delta = self.criterion(query_vector,self.task_keys,label_id) + + else: + if _r3e-1: + allkey_loss=0 + + + + + prompt_emb = self.prompt_embeddings(prompt_ids * (prompt_ids >= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2) + + inputs_embeds = inputs_embeds + prompt_emb + elif(self.name=="decoder"): + inputs_embeds = self.embed_tokens(input_ids) + + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + 
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + 
present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss is not None: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). 
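As a quick illustration of that convention (a sketch with made-up token ids, assuming the standard T5 setup where `decoder_start_token_id` equals `pad_token_id`, i.e. id 0):

labels            = [37, 32, 99, 1]    # "... </s>"
decoder_input_ids = [ 0, 37, 32, 99]   # start/pad token prepended, last label dropped
# the models in this file expose the same shift via
# model.prepare_decoder_input_ids_from_labels(labels)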
+ + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. + decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. 
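A rough usage sketch of the caching contract described above (variable names are placeholders; the prompt-specific arguments this patch adds are omitted, and the call pattern shown is the generic T5 one rather than anything specific to this file):

# First pass over the prefix: keep the decoder cache.
out = model(input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            use_cache=True)
past = out.past_key_values

# Later passes: feed only the newest decoder token together with the cache.
out = model(input_ids=input_ids,
            decoder_input_ids=next_token_id,   # shape (batch_size, 1)
            past_key_values=past,
            use_cache=True)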
+ + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
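Concretely, the deprecation path described here is avoided by passing both masks explicitly, e.g. (sketch; `my_encoder_head_mask` is a placeholder tensor):

import torch

num_layers, num_heads = model.config.num_layers, model.config.num_heads
outputs = model(input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                head_mask=my_encoder_head_mask,                       # custom encoder head mask
                decoder_head_mask=torch.ones(num_layers, num_heads))  # keep every decoder head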
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
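The freeze_all / unfreeze_all / copy_encoder helpers defined above are the continual-learning switchboard of this class: copy_encoder snapshots the encoder into a frozen query encoder (block2, final_layer_norm2, dropout2), freeze_all leaves only the prompt embeddings and task keys trainable, and unfreeze_all restores full fine-tuning while keeping the snapshot frozen. A hedged sketch of the intended call order (the real schedule lives in the downstream training scripts, not in this file):

model = T5ForConditionalGeneration.from_pretrained("t5-base")  # checkpoint name is illustrative
model.copy_encoder()    # frozen copy of the encoder used to embed queries

model.freeze_all()      # backbone frozen; prompt embeddings + task keys still train
# ... tune prompts and keys on the current task ...

model.unfreeze_all()    # backbone trainable again; the frozen copy stays frozen
# ... continue regular fine-tuning ...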
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
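Because this forward pass can return either a plain Seq2SeqLMOutput or a (Seq2SeqLMOutput, sim_loss) tuple (see the returns further below), callers have to branch on the return type. A minimal sketch (the argument set is abbreviated; the real call sites are in the downstream trainers):

outputs = model(**model_inputs)            # model_inputs assembled by the trainer
if isinstance(outputs, tuple):             # encoder also produced a key-matching loss
    seq2seq_out, sim_loss = outputs
else:
    seq2seq_out, sim_loss = outputs, None
lm_loss = seq2seq_out.loss                 # None unless labels were provided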
+ if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + lm_logits = self.lm_head(sequence_output) + # print("get_lm_logits") + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + + + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # task similarity loss + # print(loss,'\n') + # print("get_sim_loss") + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + if sim_loss is not None: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ),sim_loss + else: + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not 
None: + input_ids = input_ids[:, -1:] + + if "task_prompt_recored" in kwargs: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "task_prompt_recored": kwargs["task_prompt_recored"] + + } + else: + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5EncoderModel + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5EncoderModel.from_pretrained('t5-small') + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/diana/models/metat5notask.py b/diana/models/metat5notask.py new file mode 100755 index 0000000..14ac21c --- /dev/null +++ b/diana/models/metat5notask.py @@ -0,0 +1,2351 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. 
""" + + +import copy +import math +import os +import warnings +import numpy as np +from random import random +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from .utils import * +from transformers.activations import ACT2FN +from transformers.file_utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + replace_return_docstrings, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + 
This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example:: + + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example:: + + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained('t5-3b') + device_map = {0: [0, 1, 2], + + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
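In other words, this is an RMS-style norm: activations are scaled by the root mean square over the hidden dimension, with no mean subtraction and no bias term. A one-line equivalent of the forward pass that follows (illustrative, matching the code below up to the optional fp16 cast):

import torch

def rms_norm(x, weight, eps=1e-6):
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))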
+ """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # layer norm should always be calculated in float32 + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into float16 if necessary + if self.weight.dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + return self.weight * hidden_states + + +class T5DenseReluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = nn.functional.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config): + super().__init__() + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" + ) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
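A small worked sketch of what this bucketing means with the defaults used here (num_buckets=32, max_distance=128) in the bidirectional/encoder case:

import torch

rel = torch.tensor([[-130, -8, -1, 0, 1, 7, 8, 64, 200]])
buckets = T5Attention._relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128)
# With 32 buckets split across the two directions, each direction gets 16:
# distances 0..7 each get their own bucket, 8..127 share logarithmically
# sized buckets, and anything >= 128 (e.g. -130 and 200 above) collapses
# into the largest bucket for its direction.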
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + int_seq_length = int(seq_length) + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.training and self.gradient_checkpointing: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -int_seq_length:, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + scores += position_bias + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + assert self.is_decoder, "Only decoder can use `past_key_values`" + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}." 
+ f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseReluDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert ( + decoder_start_token_id is not None + ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. 
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None,name=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.name = name + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + if (self.name=="encoder"): + self.block2 = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) + self.final_layer_norm2 = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout2 = nn.Dropout(config.dropout_rate) + self.gen_mode = False + + + if(self.name=="encoder"): + self.criterion = BoundaryLoss() + self.prompt_num=40 + self.start_step = 0 + self.format_num = 3 + self.task_num = 11 + self.mm_ids = None + self.mm_querys = None + self.memory_querys = [] + self.memory_ids = [] + self.format2general={0:4,1:5,2:10} + self.format2tasks={0:[0,3,4],1:[1,6,5],2:[2,7,8,9,10]} + self.general_num = 1 + self.domain_num = 30 + self.domain_size = 20 + self.buffers = [-1.0]*self.domain_num + self.buffer_ids = [0]*self.domain_num + self.vectors = [np.array([0.0]*config.d_model)]*self.domain_num + self.total = 0 + self.right = 0 + self.wrong = 0 + self.train_count =0 + task_key_size = (self.task_num,config.d_model) + domain_key_size = (self.domain_num,config.d_model) + self.task_keys = nn.Parameter(torch.ones(task_key_size)) + nn.init.xavier_uniform_(self.task_keys) + self.prompt_embeddings = nn.Embedding(1600, config.d_model) + self.domain_keys = nn.Parameter(torch.ones(domain_key_size)) + nn.init.xavier_uniform_(self.domain_keys) + + self.init_weights() + # Model parallel + self.model_parallel = False + self.device_map = None + + + + + + + def reset_train_count(self): + self.train_count = 0 + + + + def get_max_keys(self,querys,keys,ids): + domain_sims = torch.matmul(querys,keys.t()) + top5 = torch.topk(domain_sims,5).indices + + _,indices = torch.sort(domain_sims,dim=1,descending=True) + indices = indices.detach().cpu().numpy().tolist() + + if ids is not None: #only training + self.start_step-=1 + + if self.start_step<=0: + + device_index = ids.device.index + + + ids_ = ids.detach().cpu().numpy().tolist() + tmp = self.buffer_ids + for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)): + if idd in tmp: + continue + + for choice in sub_indices: + if domain_sims[idx][choice].item()>self.buffers[choice]: + self.buffers[choice]=domain_sims[idx][choice].item() + self.buffer_ids[choice]=idd + self.vectors[choice]=querys[idx].detach().cpu().numpy() + + + break + + return top5 + + def get_mmr_keys(self,querys,keys,ids): + + domain_sims = 
torch.matmul(querys,keys.t())
+ top1 = torch.topk(domain_sims,1).indices
+ _,indices = torch.sort(domain_sims,dim=1,descending=True)
+ indices = indices.detach().cpu().numpy().tolist()
+ selected = top1.detach().cpu().numpy().tolist()
+ # Update the per-slot memory buffer with the highest-similarity example ids and query vectors.
+ if ids is not None:
+ self.start_step-=1
+
+ if self.start_step<=0:
+
+ device_index = ids.device.index
+ ids_ = ids.detach().cpu().numpy().tolist()
+ tmp = self.buffer_ids
+ for idx,(sub_indices,idd) in enumerate(zip(indices,ids_)):
+ if idd in tmp:
+ continue
+ for choice in sub_indices:
+ if domain_sims[idx][choice].item()>self.buffers[choice]:
+ self.buffers[choice]=domain_sims[idx][choice].item()
+ self.buffer_ids[choice]=idd
+ self.vectors[choice]=querys[idx].detach().cpu().numpy()
+ break
+
+ assert len(top1.size())==2
+ # MMR-style re-ranking: trade off similarity to the query (weight 0.6) against
+ # similarity to the keys already selected (weight 0.4) when picking the next key.
+ self_sims = torch.matmul(keys,keys.t())
+
+ self_sims = self_sims.detach().cpu().numpy()
+ domain_sims = domain_sims.detach().cpu().numpy()
+ all_selected = []
+ for line in range(domain_sims.shape[0]):
+ already = selected[line]
+ domain_sims_line = domain_sims[line]
+ for idx in range(4):
+ sims_from = self_sims[already]
+ max_sims = np.max(sims_from,axis=0)
+ domain_aggregate = domain_sims_line*0.6-max_sims*0.4
+ ids = np.argsort(-domain_aggregate).tolist()
+ for item in ids:
+ if (not item in already):
+ already.append(item)
+ break
+ all_selected.append(already)
+ # print(all_selected)
+ all_selected=torch.LongTensor(np.array(all_selected)).to(keys.device)
+ return all_selected
+
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def parallelize(self, device_map=None):
+ # Check validity of device_map
+ self.device_map = (
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+ )
+ assert_device_map(self.device_map, len(self.block))
+ self.model_parallel = True
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
+ # Load onto devices
+ for k, v in self.device_map.items():
+ for layer in v:
+ cuda_device = "cuda:" + str(k)
+ self.block[layer] = self.block[layer].to(cuda_device)
+ self.block2[layer] = self.block2[layer].to(cuda_device)
+
+ # Set embed_tokens to first layer
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
+ # Set final layer norm to last device
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+ self.final_layer_norm2 = self.final_layer_norm2.to(self.last_device)
+
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ def deparallelize(self):
+ self.model_parallel = False
+ self.device_map = None
+ self.first_device = "cpu"
+ self.last_device = "cpu"
+ for i in range(len(self.block)):
+ self.block[i] = self.block[i].to("cpu")
+ self.block2[i] = self.block2[i].to("cpu")
+ self.embed_tokens = self.embed_tokens.to("cpu")
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
+ self.final_layer_norm2 = self.final_layer_norm2.to("cpu")
+ torch.cuda.empty_cache()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ def get_query_vector(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None
+ ):
+ assert (self.name=="encoder")
+##################
+
+ # Model parallel
+ if 
self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + assert (inputs_embeds is None and input_ids is not None) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block2) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout2(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block2, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + 
logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + + position_bias = layer_outputs[2] + + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm2(hidden_states) + hidden_states = self.dropout2(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + assert False + + outs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + hidden_query_states = outs[0] + query_vector = torch.mean(hidden_query_states,dim=1) + + return query_vector + def add_negs(self,ids,vectors): + self.memory_ids.extend(ids) + self.memory_querys.extend(vectors) + self.mm_ids = torch.Tensor(self.memory_ids)#.to(self.prompt_embeddings.device) + + self.mm_querys = torch.Tensor(np.array(self.memory_querys))#.to(self.prompt_embeddings.device) + + + def forward( + self, + input_ids=None, + attention_mask=None, + id=None, + prototype_vectors = None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + train=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored = None, + ): + + if self.name=="encoder": + + query_vector = self.get_query_vector( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
inputs_embeds=inputs_embeds, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + query_vector = None + + + allkey_loss = None + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if True: + if (self.name=="encoder"): + + #add prompt + query_vector = query_vector/torch.norm(query_vector, p=2, dim=1, keepdim=True) + + tk = self.task_keys/torch.norm(self.task_keys,p=2,dim=1,keepdim=True) + + dk = self.domain_keys/torch.norm(self.domain_keys,p=2,dim=1,keepdim=True) + + query_vector[torch.isnan(query_vector)] = 0 + tk[torch.isnan(tk)]=0 + dk[torch.isnan(dk)]=0 + + if prototype_vectors is not None: + prototype_vectors = prototype_vectors/torch.norm(prototype_vectors,p=2,dim=1,keepdim=True) + prototype_vectors[torch.isnan(prototype_vectors)] = 0 + sims_prototype = torch.matmul(prototype_vectors,dk.t()) + + guess_domain_key_indexs = self.get_mmr_keys(query_vector,dk,ids=id) + + sims_domain = torch.matmul(query_vector,dk.t()) + sims = torch.matmul(query_vector,tk.t()) + guess_task_key_indexs = torch.topk(sims,1).indices.squeeze() + if (len(guess_task_key_indexs.size())==0): + guess_task_key_indexs = guess_task_key_indexs.unsqueeze(0) + + + inputs_embeds = self.embed_tokens(input_ids * (input_ids >= 0).int()) * (input_ids >= 0).int().unsqueeze(2) + prompt_ids = - (input_ids + 1) + format_id = (prompt_ids[0][10].item()-500)//self.prompt_num + + if train is None: + allkey_loss = None + disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + + prompt_ids[:,self.prompt_num:self.prompt_num+5*self.domain_size]+=disp_domain + delta_task = guess_task_key_indexs + + + + else: + + allkey_loss=0 + + + domain_similarity = torch.sum(torch.gather(sims_domain,1,guess_domain_key_indexs))/guess_domain_key_indexs.size(0)/5 + domain_cosine = 1-domain_similarity + + domain_self_similarity = torch.matmul(dk,dk.t()) + _I = torch.eye(domain_self_similarity.size(0)).to(sims_domain.device) + margined = domain_self_similarity-0.85-_I + zeros_margined = torch.zeros(domain_self_similarity.size(0),domain_self_similarity.size(1)).to(sims_domain.device) + margin_loss = torch.max(zeros_margined,margined) + + diversity_reg = torch.sum(margin_loss)/domain_self_similarity.size(0) + + + + 
allkey_loss+=domain_cosine + allkey_loss+=diversity_reg + + disp_domain = guess_domain_key_indexs.repeat_interleave(self.domain_size,dim=1)*self.domain_size + prompt_ids[:,self.prompt_num*2:self.prompt_num*2+5*self.domain_size]+=disp_domain + + prompt_emb = self.prompt_embeddings(prompt_ids * (prompt_ids >= 0).int()) * (prompt_ids >= 0).int().unsqueeze(2) + + inputs_embeds = inputs_embeds + prompt_emb + elif(self.name=="decoder"): + inputs_embeds = self.embed_tokens(input_ids) + + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + 
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + if allkey_loss 
is not None: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ),allkey_loss + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, + Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text + denoising generative setting. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. 
+ decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. 
+""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, 
+ ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +#dingwei2 + + + +@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + + def __init__(self, config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared,"encoder") + # self.query_encoder = T5Stack(encoder_config, self.shared,"encoder") + #self.query_pooler = + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + + self.decoder = T5Stack(decoder_config, self.shared,"decoder") + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gen_mode = False + + + def freeze_all(self): + self.gen_mode = True + self.encoder.gen_mode = True + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + for param in self.encoder.prompt_embeddings.parameters(): + param.requires_grad = True + self.encoder.task_keys.requires_grad = True + + def unfreeze_all(self): + self.gen_mode = False + self.encoder.gen_mode = False + for param in self.encoder.parameters(): + param.requires_grad = True + for param in self.decoder.parameters(): + param.requires_grad = True + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + def copy_encoder(self): + self.encoder.block2= copy.deepcopy(self.encoder.block) + self.encoder.final_layer_norm2 = copy.deepcopy(self.encoder.final_layer_norm) + self.encoder.dropout2 = copy.deepcopy(self.encoder.dropout) + for param in self.encoder.block2.parameters(): + param.requires_grad=False + for param in self.encoder.final_layer_norm2.parameters(): + param.requires_grad=False + for param in self.encoder.dropout2.parameters(): + param.requires_grad=False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + 
self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + id=None, + prototype_vectors=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + train=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_prompt_recored=None, + ): + + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # print("end use query") + # hidden_query_states = query_outputs[0] + + #last_hidden_state + + + sim_loss = None + + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + if labels is not None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + else: + encoder_outputs,sim_loss = self.encoder( + input_ids=input_ids, + id = id, + prototype_vectors = prototype_vectors, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + task_prompt_recored=task_prompt_recored, + train=train, + ) + + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + + hidden_states = encoder_outputs[0] + + + + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # If decoding with past key value states, only the last tokens + # should be given as an input + if past_key_values is not None: + assert labels is None, "Decoder should not use cached key value states when training." 
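+ # When generating with use_cache=True, `past_key_values` already holds the keys/values
+ # for every previously decoded position, so only the newest decoder token needs to be
+ # fed; the two branches below therefore trim the decoder inputs to their last position
+ # (shape (batch_size, 1)), matching what `prepare_inputs_for_generation` passes in.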
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids[:, -1:]
+            if decoder_inputs_embeds is not None:
+                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+            hidden_states = hidden_states.to(self.decoder.first_device)
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(self.decoder.first_device)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.encoder.first_device)
+            self.lm_head = self.lm_head.to(self.encoder.first_device)
+            sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim ** -0.5)
+
+        lm_logits = self.lm_head(sequence_output)
+        # print("get_lm_logits")
+        loss = None
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))  # task similarity loss
+            # print(loss, '\n')
+            # print("get_sim_loss")
+            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+        if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        if sim_loss is not None:
+            return Seq2SeqLMOutput(
+                loss=loss,
+                logits=lm_logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            ), sim_loss
+        else:
+            return Seq2SeqLMOutput(
+                loss=loss,
+                logits=lm_logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        if "task_prompt_recored" in kwargs:
+            return {
+                "decoder_input_ids": input_ids,
+                "past_key_values": past,
+                "encoder_outputs": encoder_outputs,
+                "attention_mask": attention_mask,
+                "head_mask": head_mask,
+                "decoder_head_mask": decoder_head_mask,
+                "cross_attn_head_mask": cross_attn_head_mask,
+                "use_cache": use_cache,
+                "task_prompt_recored": kwargs["task_prompt_recored"],
+            }
+        else:
+            return {
+                "decoder_input_ids": input_ids,
+                "past_key_values": past,
+                "encoder_outputs": encoder_outputs,
+                "attention_mask": attention_mask,
+                "head_mask": head_mask,
+                "decoder_head_mask": decoder_head_mask,
+                "cross_attn_head_mask": cross_attn_head_mask,
+                "use_cache": use_cache,
+            }
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return self._shift_right(labels)
+
+    def _reorder_cache(self, past, beam_idx):
+        # if decoder past is not included in output
+        # speedy decoding is disabled and no need to reorder
+        if past is None:
+            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
+            return past
+
+        reordered_decoder_past = ()
+        for layer_past_states in past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` is at 2nd position
+            reordered_layer_past_states = ()
+            for layer_past_state in layer_past_states:
+                # need to set correct `past` for each of the four key / value states
+                reordered_layer_past_states = reordered_layer_past_states + (
+                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
+                )
+
+            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
+            assert len(reordered_layer_past_states) == len(layer_past_states)
+
+            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+        return reordered_decoder_past
+
+
+@add_start_docstrings(
+    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+    T5_START_DOCSTRING,
+)
+class T5EncoderModel(T5PreTrainedModel):
+    authorized_missing_keys = [
+        r"encoder\.embed_tokens\.weight",
+    ]
+
+    def __init__(self, config: T5Config):
+        super().__init__(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = T5Stack(encoder_config, self.shared)
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    @add_start_docstrings(PARALLELIZE_DOCSTRING)
+    def parallelize(self, device_map=None):
+        self.device_map = (
+            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
+            if device_map is None
+            else device_map
+        )
+        assert_device_map(self.device_map, len(self.encoder.block))
+        self.encoder.parallelize(self.device_map)
+        self.model_parallel = True
+
+    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
+    def deparallelize(self):
+        self.encoder.deparallelize()
+        self.encoder = self.encoder.to("cpu")
+        self.model_parallel = False
+        self.device_map = None
+        torch.cuda.empty_cache()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+        self.encoder.set_input_embeddings(new_embeddings)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        Returns:
+
+        Example::
+
+            >>> from transformers import T5Tokenizer, T5EncoderModel
+            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            >>> model = T5EncoderModel.from_pretrained('t5-small')
+            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
+            >>> outputs = model(input_ids=input_ids)
+            >>> last_hidden_states = outputs.last_hidden_state
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        return encoder_outputs
diff --git a/diana/models/utils.py b/diana/models/utils.py
new file mode 100644
index 0000000..ceb273a
--- /dev/null
+++ b/diana/models/utils.py
@@ -0,0 +1,61 @@
+import torch.nn.functional as F
+from torch import nn
+import torch
+import copy
+# adaptive decision boundary
+# [code] https://github.com/thuiar/Adaptive-Decision-Boundary
+# [paper] https://www.aaai.org/AAAI21Papers/AAAI-9723.ZhangH.pdf
+def euclidean_metric(a, b):
+    n = a.shape[0]
+    m = b.shape[0]
+    a = a.unsqueeze(1).expand(n, m, -1)
+    b = b.unsqueeze(0).expand(n, m, -1)
+    logits = -((a - b) ** 2).sum(dim=2)
+    return logits
+
+def cosine_metric(a, b):
+    n = a.shape[0]
+    m = b.shape[0]
+    a = a.unsqueeze(1).expand(n, m, -1)
+    b = b.unsqueeze(0).expand(n, m, -1)
+    logits = (a * b).sum(dim=2)
+    # logits = -logits+1
+    return logits
+
+
+class BoundaryLoss(nn.Module):
+
+    def __init__(self, num_labels=10, feat_dim=2):
+        super(BoundaryLoss, self).__init__()
+        self.num_labels = num_labels
+        self.feat_dim = feat_dim
+
+        self.delta = nn.Parameter(torch.ones(20).cuda())
+        nn.init.normal_(self.delta)
+
+    def forward(self, pooled_output, centroids_, labels, group=0):
+        centroids = copy.deepcopy(centroids_.detach())
+        logits = cosine_metric(pooled_output, centroids)
+        probs, preds = F.softmax(logits.detach(), dim=1).max(dim=1)
+        delta = F.softplus(self.delta)
+        # c = torch.Tensor([[0.0]*768]*(centroids[labels]).size(0)).to(pooled_output.device)
+        # print(self.delta.size())
+        # print("labels", labels.size())
+        c = centroids[labels]
+        c = c / torch.norm(c, p=2, dim=1, keepdim=True)
+        # c.requires_grad = False
+        d = delta[labels]
+        x = pooled_output
+
+        euc_dis = 1 - ((c * x).sum(dim=1))  # torch.norm(x - c, 2, 1).view(-1)
+        pos_mask = (euc_dis > d).type(torch.cuda.FloatTensor)
+        neg_mask = (euc_dis < d).type(torch.cuda.FloatTensor)
+
+        pos_loss = (euc_dis - d) * pos_mask
+        neg_loss = (d - euc_dis) * neg_mask
+        loss = pos_loss.mean() + neg_loss.mean()
+
+        return loss, delta
\ No newline at end of file
diff --git a/diana/r.txt b/diana/r.txt
new file mode 100644
index 0000000..8f2579d
--- /dev/null
+++ b/diana/r.txt
@@ -0,0 +1,56 @@
+absl-py==1.0.0
+aiohttp==3.8.1
+aiosignal==1.2.0
+async-timeout==4.0.2
+asynctest==0.13.0
+attrs==21.4.0
+certifi==2022.5.18.1
+charset-normalizer==2.0.12
+click==8.1.3
+datasets==2.2.2
+dill==0.3.4
+filelock==3.7.0
+frozenlist==1.3.0
+fsspec==2022.5.0
+future==0.18.2
+fuzzywuzzy==0.18.0
+huggingface-hub==0.7.0
+idna==3.3
+importlib-metadata==4.11.4
+joblib==1.1.0
+multidict==6.0.2
+multiprocess==0.70.12.2
+nlp==0.4.0
+nltk==3.7
+numpy==1.21.6
+packaging==21.3
+pandas==1.1.5
+pyarrow==8.0.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytz==2022.1
+PyYAML==6.0
+regex==2022.4.24
+requests==2.27.1
+responses==0.18.0
+rouge-score==0.0.4
+sacremoses==0.0.53
+scikit-learn==1.0.2
+scipy==1.7.3
+six==1.16.0
+sklearn==0.0
+threadpoolctl==3.1.0
+tokenizers==0.10.3
+tqdm==4.64.0
+transformers==4.12.5
+typing_extensions==4.2.0
+urllib3==1.26.9
+wget==3.2
+xxhash==3.0.0
+yarl==1.7.2
+zipp==3.8.0
+numpy==1.20.3
+rouge_score
+omegaconf
+jsonlines
+spacy
\ No newline at end of file
diff --git a/diana/readme.md b/diana/readme.md
new file mode 100644
index 0000000..883004b
--- /dev/null
+++ b/diana/readme.md
@@ -0,0 +1,55 @@
+# Diana
+The PyTorch implementation of the paper [Domain Incremental Lifelong Learning in an Open World](https://arxiv.org/abs/2305.06555) (ACL 2023).
+
+## Requirements
+```bash
+cd DomainIncrementalLL
+pip install -r r.txt
+```
+
+## QA tasks
+### Data preparation
+Download the datasets, unzip them, and place them in /data_process:
+
+[QA Dataset](https://drive.google.com/file/d/1x8vvrdfwKCXpT_M4NA_eso8cX-LlMWsQ/view?usp=share_link)
+
+Tokenize the plain-text datasets:
+```bash
+cd downstream
+python create_l2p.py --model_name_or_path t5-base --dataset_name null --output_dir null
+```
+
+### Train and evaluate on 4 GPUs
+```bash
+bash ./diana.sh ./ll 8 48
+```
+
+### Ablations: no task prompts & no meta prompts
+To evaluate the ablation without task prompts:
+```bash
+bash ./dianawotask.sh ./ll 8 48
+```
+To evaluate the ablation without meta prompts:
+```bash
+bash ./dianawometa.sh ./ll 8 48
+```
+
+## DecaNLP tasks
+
+[Dataset](https://drive.google.com/file/d/1pzIoTzooaecsU4HXP_n4-_7Uuxv9i4A-/view?usp=share_link)
+
+Extract plain texts from the raw files:
+```bash
+cd downstreamdeca
+python extract.py
+```
+Tokenize the plain texts:
+```bash
+python toknize.py
+```
+
+Run the training and evaluation script:
+```bash
+bash ./diana.sh ./ll 8 48
+```
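
For reference, below is a minimal sketch of exercising the `BoundaryLoss` added in `diana/models/utils.py` on its own. The feature dimension, number of clusters, random inputs, and import path are illustrative assumptions rather than values taken from the training scripts, and a CUDA device is required because the loss allocates its boundary radii with `.cuda()`.

```python
# Standalone smoke test of BoundaryLoss (diana/models/utils.py).
# Shapes, random inputs, and the import path are assumptions for illustration.
import torch
import torch.nn.functional as F

from models.utils import BoundaryLoss  # assumed path; adjust to your checkout

num_clusters, feat_dim, batch_size = 20, 768, 8  # delta is hardcoded to 20 radii, so keep ids < 20

# The loss compares 1 - cosine similarity against a learnable radius,
# so use unit-norm features and centroids.
pooled = F.normalize(torch.randn(batch_size, feat_dim), dim=-1).cuda()
centroids = F.normalize(torch.randn(num_clusters, feat_dim), dim=-1).cuda()
labels = torch.randint(0, num_clusters, (batch_size,)).cuda()

criterion = BoundaryLoss(num_labels=num_clusters, feat_dim=feat_dim)
loss, delta = criterion(pooled, centroids, labels)  # delta: softplus'd radii, shape (20,)
loss.backward()  # gradients flow only into the boundary parameter; centroids are detached
print(loss.item(), delta.shape)
```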