parent b99e3d74c9
commit a33cb4f8b9
@@ -10,7 +10,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
     --overwrite_cache \
     --prompt_column content \
     --response_column summary \
-    --model_name_or_path chatglm2-6b \
+    --model_name_or_path THUDM/chatglm2-6b \
     --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
     --output_dir ./output/$CHECKPOINT \
     --overwrite_output_dir \
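Note: the hunk above swaps a bare local path for the Hugging Face Hub repo id, so the weights resolve even when no local chatglm2-6b/ directory exists. A minimal sketch of the loading this enables (standard transformers calls, not part of this diff):

from transformers import AutoModel, AutoTokenizer

# "THUDM/chatglm2-6b" is a Hub repo id: downloaded and cached on first use.
# A bare "chatglm2-6b" would only resolve if that directory existed locally.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)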
@@ -178,7 +178,7 @@ def main():
         return model_inputs
 
     def preprocess_function_train(examples):
-        max_seq_length = data_args.max_source_length + data_args.max_target_length
+        max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
 
         model_inputs = {
             "input_ids": [],
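The "+ 1" makes room for the single eos token appended when prompt and response ids are concatenated into one training sequence; without it, a maximally long example overflows the budget by one token. A self-contained sketch of the arithmetic (the encode calls are assumed from the surrounding preprocessing code, which this hunk truncates):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

max_source_length, max_target_length = 64, 128
a_ids = tokenizer.encode("prompt text", truncation=True, max_length=max_source_length)
b_ids = tokenizer.encode("answer text", add_special_tokens=False,
                         truncation=True, max_length=max_target_length)
# One trailing eos is appended, so the worst case is source + target + 1.
input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
assert len(input_ids) <= max_source_length + max_target_length + 1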
@@ -335,7 +335,7 @@ def main():
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        save_prefixencoder=model_args.pre_seq_len is not None
+        save_changed=model_args.pre_seq_len is not None
     )
 
     # Training
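save_prefixencoder is renamed to save_changed, presumably to match the constructor of the rewritten trainer (the ptuning/trainer.py diff below is suppressed, so the exact implementation is not visible here). A minimal sketch of how such a flag is typically wired, with all names assumed rather than taken from the suppressed diff:

from transformers import Trainer

class PrefixTrainer(Trainer):
    # Hypothetical sketch: save only trainable weights when save_changed is set.
    def __init__(self, *args, save_changed: bool = False, **kwargs):
        self.save_changed = save_changed
        super().__init__(*args, **kwargs)

    def _save(self, output_dir=None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        if self.save_changed:
            # Keep only parameters with requires_grad=True; under p-tuning v2
            # that is the prefix encoder, so checkpoints stay small.
            full = self.model.state_dict()
            filtered = {k: full[k] for k, v in self.model.named_parameters()
                        if v.requires_grad}
            self.model.save_pretrained(output_dir, state_dict=filtered)
        else:
            super()._save(output_dir=output_dir, state_dict=state_dict)

With save_changed=model_args.pre_seq_len is not None, full fine-tuning (no pre_seq_len) still saves the whole model, while p-tuning checkpoints contain only the prefix weights.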
ptuning/trainer.py (3782 lines changed)
File diff suppressed because it is too large
@@ -19,7 +19,7 @@ from torch import nn
 from torch.utils.data import Dataset
 
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from trainer import Trainer
+from trainer import PrefixTrainer
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import logging
 
@@ -27,7 +27,7 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-class Seq2SeqTrainer(Trainer):
+class Seq2SeqTrainer(PrefixTrainer):
     def evaluate(
         self,
         eval_dataset: Optional[Dataset] = None,
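Together with the import change above, this rebases Seq2SeqTrainer onto PrefixTrainer, so the save_changed keyword passed in the -335 hunk above reaches the trainer's checkpoint logic while the generate-based evaluate/predict overrides stay as they were. The resulting hierarchy, with bodies elided:

from transformers import Trainer

class PrefixTrainer(Trainer):          # checkpoint filtering via save_changed
    ...

class Seq2SeqTrainer(PrefixTrainer):   # generate-based evaluate/predict
    ...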