DAMO-ConvAI/metaretriever/run_seq2seq.py

780 lines
32 KiB
Python

#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import logging
import os
import sys
from dataclasses import dataclass, field
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from typing import Optional
import numpy as np
from datasets import load_dataset
import transformers
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
default_data_collator,
set_seed
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from uie.extraction import constants
from uie.extraction.record_schema import RecordSchema
from uie.extraction.predict_parser import decoding_format_dict
from uie.extraction.extraction_metrics import get_extract_metrics
from uie.extraction.noiser.spot_asoc_noiser import SpotAsocNoiser
from uie.extraction.dataset_processer import PrefixGenerator
from uie.seq2seq.constrained_seq2seq import (
ConstraintSeq2SeqTrainingArguments,
ConstraintSeq2SeqTrainer,
OriginalConstraintSeq2SeqTrainer,
UIEPretrainConstraintSeq2SeqTrainer,
UIEFinetuneConstraintSeq2SeqTrainer,
MetaPretrainConstraintSeq2SeqTrainer,
MetaFinetuneConstraintSeq2SeqTrainer,
)
from uie.seq2seq.data_collator import (
DataCollatorForMetaSeq2Seq,
DynamicSSIGenerator,
)
from uie.seq2seq.features import RecordFeature
from uie.seq2seq.model import PromptSeq2SeqTransformer
from uie.seq2seq.noise_record import create_noised_record
import pdb
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Instances are filled in by ``HfArgumentParser`` in ``main``; the ``help``
    entry of each field's metadata becomes its command-line help text.
    """

    # --- Where the model, config and tokenizer are loaded from -------------
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where to store the pretrained models downloaded from huggingface.co"),
    )

    # --- Hub / tokenizer loading options ------------------------------------
    use_fast_tokenizer: bool = field(
        default=False,
        metadata=dict(
            help="Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        ),
    )
    model_revision: str = field(
        default="main",
        metadata=dict(
            help="The specific model version to use (can be a branch name, tag name or commit id)."
        ),
    )
    use_auth_token: bool = field(
        default=False,
        metadata=dict(
            help=(
                "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                "with private models)."
            )
        ),
    )

    # --- Script-specific switches -------------------------------------------
    from_checkpoint: bool = field(
        default=False,
        metadata=dict(help="Whether load from checkpoint to continue learning"),
    )
    load_config_only: bool = field(
        default=False,
        metadata=dict(help="Whether load model config only from checkpoint"),
    )
    use_prompt_tuning_model: bool = field(
        default=False,
        metadata=dict(help="Whether use prompt tuning model"),
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``__post_init__`` validates that some data source was given, checks the
    extensions of ``train_file`` / ``validation_file``, and defaults
    ``val_max_target_length`` to ``max_target_length``.
    """

    task: str = field(
        default="summarization",
        metadata={
            "help": "The name of the task, should be summarization (or summarization_{dataset} for evaluating "
            "pegasus) or translation (or translation_{xx}_to_{yy})."
        },
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default='text',
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    record_column: Optional[str] = field(
        default='record',
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
            "(a jsonlines or csv file)."
        },
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
            "(a jsonlines or csv file)."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_prefix_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum prefix length."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
            "during ``evaluate`` and ``predict``."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
            "value if set."
        },
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
            "which is used during ``evaluate`` and ``predict``."
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    meta_negative: int = field(
        default=-1, metadata={"help": "Negative Schema Number in Training."}
    )
    ordered_prompt: bool = field(
        default=True,
        metadata={
            "help": "Whether to sort the spot prompt and asoc prompt or not."
        },
    )
    decoding_format: str = field(
        default='tree',
        metadata={"help": "Decoding Format, valid in %s" % decoding_format_dict.keys()}
    )
    # Optional: main() only reads the schema when this path is set and exists.
    record_schema: Optional[str] = field(
        default=None, metadata={"help": "The input event schema file."}
    )
    spot_noise: float = field(
        default=0., metadata={"help": "The noise rate of null spot."}
    )
    asoc_noise: float = field(
        default=0., metadata={"help": "The noise rate of null asoc."}
    )
    meta_positive_rate: float = field(
        default=1., metadata={"help": "The keep rate of positive spot."}
    )

    def __post_init__(self):
        """Validate the data sources and fill in derived defaults.

        Raises:
            ValueError: if no dataset name and no train/validation/test file is
                given, or if ``train_file`` / ``validation_file`` is not a
                ``.csv`` / ``.json`` file.
        """
        # A predict-only run (only `test_file`) is a legitimate configuration.
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training/validation/test file.")
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        for arg_name in ("train_file", "validation_file"):
            path = getattr(self, arg_name)
            if path is not None:
                extension = path.split(".")[-1]
                if extension not in ("csv", "json"):
                    raise ValueError(f"`{arg_name}` should be a csv or a json file.")
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
def main():
    """Parse arguments, build tokenizer/model/datasets and run train, eval and/or predict.

    Arguments come either from a single ``.json`` file given as the only
    command-line argument, or from the command line itself, and are parsed
    into ``ModelArguments``, ``DataTrainingArguments`` and
    ``ConstraintSeq2SeqTrainingArguments``.

    Returns:
        dict: the rounded metrics of ``trainer.evaluate`` when ``do_eval`` is
        set, otherwise an empty dict. ``None`` is returned early when none of
        ``do_train`` / ``do_eval`` / ``do_predict`` is set.

    Raises:
        ValueError: if ``output_dir`` is non-empty, has no checkpoint and
            ``--overwrite_output_dir`` is not given, or if the model config
            has no ``decoder_start_token_id``.
        NotImplementedError: if spot/asoc noise is requested for a decoding
            format other than ``spotasoc``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConstraintSeq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            # NOTE(review): training only resumes from this checkpoint when
            # `--from_checkpoint` is also set (see the Training section below).
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
    logger.info("Options:")
    logger.info(model_args)
    logger.info(data_args)
    logger.info(training_args)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: either a public dataset from the hub (downloaded automatically),
    # or local train/validation/test files. In distributed training, load_dataset
    # guarantees that only one local process can concurrently download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
    else:
        # Local files are always read through the project's `uie_json.py` loading script.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if training_args.do_eval and data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        if training_args.do_predict and data_args.test_file is not None:
            data_files["test"] = data_args.test_file
        logger.info(data_files)
        datasets = load_dataset("uie_json.py", data_files=data_files, block_size=(10 << 22))
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    logger.info(datasets)

    # Load pretrained model and tokenizer.
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_name_or_path = model_args.config_name if model_args.config_name else model_args.model_name_or_path
    logger.info("Load Config: %s", config_name_or_path)
    config = AutoConfig.from_pretrained(
        config_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.max_length = data_args.max_target_length

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Special tokens stripped from decoded text by `postprocess_text` (bos, eos, pad, if defined).
    to_remove_token_list = [
        token for token in (tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token) if token
    ]

    MODEL = PromptSeq2SeqTransformer if model_args.use_prompt_tuning_model else AutoModelForSeq2SeqLM

    if model_args.load_config_only:
        model = MODEL.from_config(config)
    else:
        model = MODEL.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            mirror='tuna',
        )

    if training_args.do_train:
        # Build the vocabulary once rather than once per candidate token.
        vocab = tokenizer.get_vocab()
        to_add_special_token = [
            special_token
            for special_token in [
                constants.type_start,
                constants.type_end,
                constants.text_start,
                constants.span_start,
                constants.spot_prompt,
                constants.asoc_prompt,
            ]
            if special_token not in vocab
        ]
        # Tokenizers without any additional special tokens may not expose this key at all.
        existing_special_tokens = tokenizer.special_tokens_map_extended.get('additional_special_tokens', [])
        tokenizer.add_special_tokens(
            {"additional_special_tokens": list(existing_special_tokens) + to_add_special_token}
        )
        model.resize_token_embeddings(len(tokenizer))

    logger.info(tokenizer)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    if data_args.record_schema and os.path.exists(data_args.record_schema):
        record_schema = RecordSchema.read_from_file(data_args.record_schema)
    else:
        record_schema = None

    # `source_prefix` starting with 'meta' selects the meta-prompt pipeline
    # (no textual prefix, spot/asoc prompts built by the data collator).
    use_meta_prompt = data_args.source_prefix is not None and data_args.source_prefix.startswith('meta')

    if data_args.source_prefix is None or use_meta_prompt:
        prefix = ""
    elif data_args.source_prefix == 'schema':
        prefix = PrefixGenerator.get_schema_prefix(schema=record_schema)
    else:
        prefix = data_args.source_prefix

    logger.info(f"Prefix: {prefix}")
    logger.info(f"Prefix Length: {len(tokenizer.tokenize(prefix))}")

    # Preprocessing the datasets: we need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    elif training_args.do_eval:
        column_names = datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # To serialize preprocess_function below, each of those variables needs to be defined.
    text_column = data_args.text_column
    record_column = data_args.record_column
    logger.info('Using src: %s and tgt: %s' % (text_column, record_column))

    # Temporarily set max_target_length for training; it is switched to
    # `val_max_target_length` before the eval/test splits are tokenized.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.error(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
        )

    def preprocess_function(examples):
        """Tokenize a batch of examples and attach noised records and prompt fields."""
        inputs = examples[text_column]
        targets = examples[record_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
        model_inputs["text"] = inputs

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(_label if _label != tokenizer.pad_token_id else -100) for _label in label]
                for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]

        # Use the stored noised record when present, otherwise build one from the gold annotations.
        noised_record_list = []
        for idx, noised_record in enumerate(examples["noised_record"]):
            if noised_record is None:
                tokens = examples["tokens"][idx]
                entity_list = examples["entity"][idx]
                triple_list = examples["relation"][idx]
                event_list = examples["event"][idx]
                noised_record = create_noised_record(tokens, entity_list, triple_list, event_list)
            noised_record_list.append(noised_record)
        model_inputs["noised_record"] = noised_record_list

        model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
        if use_meta_prompt:
            model_inputs['spots'] = examples['spot']
            model_inputs['asocs'] = examples['asoc']
            model_inputs['spot_asoc'] = examples['spot_asoc']
            # sample_prompt=True for Finetune and Pretrain
            model_inputs['sample_prompt'] = [True] * len(model_inputs['input_ids'])
        return model_inputs

    def preprocess_function_eval(examples):
        """Same as `preprocess_function`, but never samples prompts (sample_prompt=False)."""
        model_inputs = preprocess_function(examples)
        model_inputs['sample_prompt'] = [False] * len(model_inputs['input_ids'])
        return model_inputs

    def postprocess_text(x_str):
        """Remove bos/eos/pad token strings from decoded text and strip whitespace."""
        for to_remove_token in to_remove_token_list:
            x_str = x_str.replace(to_remove_token, '')
        return x_str.strip()

    def decode_token_ids(token_ids):
        """Decode id sequences (keeping special tokens) and clean them with `postprocess_text`.

        -100 is not a real token id (it marks ignored label positions and may
        also appear as padding in gathered predictions), so it is mapped to the
        pad id before decoding.
        """
        token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        return [postprocess_text(text) for text in decoded]

    def prepare_split(split, max_samples, function):
        """Optionally truncate `datasets[split]` to `max_samples` rows, then tokenize it with `function`."""
        dataset = datasets[split]
        if max_samples is not None:
            dataset = dataset.select(range(max_samples))
        return dataset.map(
            function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            features=RecordFeature,
        )

    logger.info("Start Data Preprocessing ...")
    train_dataset = eval_dataset = test_dataset = None
    if training_args.do_train:
        train_dataset = prepare_split("train", data_args.max_train_samples, preprocess_function)
    if training_args.do_eval:
        # `preprocess_function` reads `max_target_length` when `map` runs, so switch it first.
        max_target_length = data_args.val_max_target_length
        eval_dataset = prepare_split("validation", data_args.max_val_samples, preprocess_function_eval)
    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        test_dataset = prepare_split("test", data_args.max_test_samples, preprocess_function_eval)
    logger.info("End Data Preprocessing ...")

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif use_meta_prompt:
        if data_args.spot_noise > 0 or data_args.asoc_noise > 0:
            if data_args.decoding_format != 'spotasoc':
                raise NotImplementedError(
                    "Spot/asoc noise is only implemented for decoding_format `spotasoc`, "
                    f"got `{data_args.decoding_format}`."
                )
            spot_asoc_noiser = SpotAsocNoiser(
                spot_noise_ratio=data_args.spot_noise,
                asoc_noise_ratio=data_args.asoc_noise,
                null_span=constants.null_span,
            )
        else:
            spot_asoc_noiser = None
        data_collator = DataCollatorForMetaSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
            max_length=data_args.max_source_length,
            max_prefix_length=data_args.max_prefix_length,
            max_target_length=data_args.max_target_length,
            negative_sampler=DynamicSSIGenerator(
                tokenizer=tokenizer,
                schema=record_schema,
                positive_rate=data_args.meta_positive_rate,
                negative=data_args.meta_negative,
                ordered_prompt=data_args.ordered_prompt,
            ),
            spot_asoc_nosier=spot_asoc_noiser,
            decoding_format=data_args.decoding_format,
        )
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    def compute_metrics(eval_preds):
        """Decode generated and gold sequences and score them with `get_extract_metrics`."""
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace -100 up front so `gen_len` does not count it as generated tokens.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = decode_token_ids(preds)
        decoded_labels = decode_token_ids(labels)

        result = get_extract_metrics(
            pred_lns=decoded_preds,
            tgt_lns=decoded_labels,
            label_constraint=record_schema,
            decoding_format=data_args.decoding_format,
        )
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Initialize our Trainer
    trainer_classes = {
        "uie_pretrain": UIEPretrainConstraintSeq2SeqTrainer,
        "uie_finetune": UIEFinetuneConstraintSeq2SeqTrainer,
        "meta_pretrain": MetaPretrainConstraintSeq2SeqTrainer,
        "meta_finetune": MetaFinetuneConstraintSeq2SeqTrainer,
    }
    TRAINER = trainer_classes.get(training_args.trainer_type, OriginalConstraintSeq2SeqTrainer)

    trainer = TRAINER(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        decoding_type_schema=record_schema,
        decoding_format=data_args.decoding_format,
        source_prefix=prefix,
        task=data_args.task,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if model_args.from_checkpoint:
            if last_checkpoint is not None:
                checkpoint = last_checkpoint
            elif os.path.isdir(model_args.model_name_or_path):
                checkpoint = model_args.model_name_or_path
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w", encoding="utf-8") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f" {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
        results = {k: round(v, 4) for k, v in results.items()}

        # NOTE(review): this regenerates over the same eval set as `evaluate` above;
        # it is kept to obtain the raw predictions written to disk below.
        eval_results = trainer.predict(
            eval_dataset,
            metric_key_prefix="eval",
            max_length=data_args.val_max_target_length,
            num_beams=data_args.num_beams,
        )

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w", encoding="utf-8") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f" {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            if training_args.predict_with_generate:
                eval_preds = decode_token_ids(eval_results.predictions)
                output_eval_preds_file = os.path.join(training_args.output_dir, "eval_preds_seq2seq.txt")
                with open(output_eval_preds_file, "w", encoding="utf-8") as writer:
                    writer.write("\n".join(eval_preds))

    if training_args.do_predict:
        logger.info("*** Test ***")

        test_results = trainer.predict(
            test_dataset,
            metric_key_prefix="test",
            max_length=data_args.val_max_target_length,
            num_beams=data_args.num_beams,
        )
        test_metrics = test_results.metrics
        test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)

        output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
        if trainer.is_world_process_zero():
            with open(output_test_result_file, "w", encoding="utf-8") as writer:
                logger.info("***** Test results *****")
                for key, value in sorted(test_metrics.items()):
                    logger.info(f" {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            if training_args.predict_with_generate:
                test_preds = decode_token_ids(test_results.predictions)
                output_test_preds_file = os.path.join(training_args.output_dir, "test_preds_seq2seq.txt")
                with open(output_test_preds_file, "w", encoding="utf-8") as writer:
                    writer.write("\n".join(test_preds))

    return results
def _mp_fn(index):
    """Spawn callback for ``xla_spawn`` (TPUs): run ``main`` in this process.

    ``index`` is supplied by the spawner (presumably the process ordinal) and
    is not needed here.
    """
    del index  # unused; kept only to match the spawn callback signature
    main()


if __name__ == "__main__":
    # Script entry point.
    main()