fix error commit & update pro
This commit is contained in:
parent
1cdc7bef6f
commit
64f4da714a
|
@ -4,14 +4,14 @@ Authors: Feifan Song, Bowen Yu, Minghao Li, Haiyang Yu, Fei Huang, Yongbin Li, H
|
|||
arXiv: [Abstract](https://arxiv.org/abs/2306.17492) / [PDF](https://arxiv.org/pdf/2306.17492.pdf)
|
||||
|
||||
## Abstract
|
||||
Large language models (LLMs) often contain misleading content, emphasizing the need to align them with human values to ensure secure AI systems. Reinforcement learning from human feedback (RLHF) has been employed to achieve this alignment by combining a reward model, typically based on Bradley-Terry paired comparison, with an RL algorithm such as Proximal Policy Optimization (PPO) to optimize LLM responses. However, RLHF exhibits complexity, instability, and sensitivity to hyperparameters. In this paper, we propose Preference Ranking Optimization (PRO) as an alternative to PPO for directly aligning LLMs with the Bradley-Terry comparison. PRO extends the pairwise Bradley-Terry comparison to accommodate preference rankings of any length. By iteratively contrasting the likelihood of generating responses, PRO instructs the LLM to prioritize the best response while progressively ranking the remaining responses. In this manner, PRO effectively transforms human alignment into aligning the probability ranking of n responses generated by LLM with the preference ranking of humans towards these responses. Experiments have shown that PRO outperforms existing alignment algorithms, achieving comparable results to ChatGPT and human responses through automatic-based, reward-based, GPT-4, and human evaluations. Furthermore, we demonstrate that longer, more diverse, and higher-quality preference ranking sequences can consistently enhance the performance of human alignment.
|
||||
Large language models (LLMs) often contain misleading content, emphasizing the need to align them with human values to ensure secure AI systems. Reinforcement learning from human feedback (RLHF) has been employed to achieve this alignment. However, it encompasses two main drawbacks: (1) RLHF exhibits complexity, instability, and sensitivity to hyperparameters in contrast to SFT. (2) Despite massive trial-and-error, multiple sampling is reduced to pair-wise contrast, thus lacking contrasts from a macro perspective. In this paper, we propose Preference Ranking Optimization (PRO) as an efficient SFT algorithm to directly fine-tune LLMs for human alignment. PRO extends the pair-wise contrast to accommodate preference rankings of any length. By iteratively contrasting candidates, PRO instructs the LLM to prioritize the best response while progressively ranking the rest responses. In this manner, PRO effectively transforms human alignment into aligning the probability ranking of n responses generated by LLM with the preference ranking of humans towards these responses. Experiments have shown that PRO outperforms baseline algorithms, achieving comparable results to ChatGPT and human responses through automatic-based, reward-based, GPT-4, and human evaluations.
|
||||
|
||||
## The pipeline of PRO
|
||||
<div align="center"><img src="./resources/pipeline.jpg" style="zoom:100%"></div>
|
||||
|
||||
## Results
|
||||
### Automatic Evaluation
|
||||
<div align="center"><img src="./resources/automatic.jpg" style="zoom:100%"></div>
|
||||
### Automatic Evaluation on *HH-RLHF*
|
||||
<div align="center"><img src="./resources/automatic_hh.jpg" style="zoom:100%"></div>
|
||||
|
||||
### GPT-4 Evaluation
|
||||
<div align="center"><img src="./resources/gpt4.jpg" style="zoom:33%"></div>
|
||||
|
@ -19,32 +19,63 @@ Large language models (LLMs) often contain misleading content, emphasizing the n
|
|||
### Human Evaluation
|
||||
<div align="center"><img src="./resources/human.jpg" style="zoom:33%"></div>
|
||||
|
||||
### Automatic Evaluation on *Summarize From Feedback*
|
||||
<div align="center"><img src="./resources/automatic_summarize.jpg" style="zoom:50%"></div>
|
||||
|
||||
## Running!
|
||||
### Data Preparation
|
||||
1. Download [data.zip](https://ylab-mobile-prod.oss-cn-beijing.aliyuncs.com/yueli.ybw/pro_data.zip) and unzip it.
|
||||
2. Place the unzipped ```data/``` folder in the root directory of the project.
|
||||
3. You can also get the raw data from [this repo](https://github.com/anthropics/hh-rlhf), and run the following command to preprocess it to get the same data as ```train_len2/``` in ```data.zip```:
|
||||
We provide the preprocessed data for training and testing, which can be get with following steps:
|
||||
1. Download [data.zip](https://ylab-mobile-prod.oss-cn-beijing.aliyuncs.com/yueli.ybw/data.zip) and unzip it.
|
||||
2. Place the unzipped ```data``` folder in the root directory of the project.
|
||||
|
||||
Besides, we also provide the scripts for preprocessing the raw data. Please follow the steps below to prepare the data:
|
||||
1. Create a directory named ```data``` in the root directory of this project.
|
||||
2. Create a directory named ```data/raw_data``` in the ```data``` directory.
|
||||
3. Download the raw data from [*HH-RLHF*](https://github.com/anthropics/hh-rlhf) or [*Summarize From Feedback*](https://github.com/openai/summarize-from-feedback), which should be named as ```hhrlhf``` or ```summarize_from_feedback```, and put it in the ```data/raw_data``` directory.
|
||||
4. Run the following command to preprocess the data:
|
||||
|
||||
```
|
||||
cd train/preprocess_data
|
||||
# For HH-RLHF
|
||||
cd train/hh_preprocess_data
|
||||
python step_1_process.py
|
||||
python step_2_get_train_data.py
|
||||
python step_3_get_test_data.py
|
||||
|
||||
# For Summarize From Feedback
|
||||
cd ../summarize_preprocess_data
|
||||
python step_1_process.py
|
||||
python step_2_get_train_data.py
|
||||
python step_3_get_test_data.py
|
||||
```
|
||||
|
||||
### Train
|
||||
We provide the training script for training the model. For example, you can run the following command to train the model:
|
||||
We provide the training scripts for training the model. For example, you can run the following commands to train the model:
|
||||
```
|
||||
cd train
|
||||
./train.sh [id_of_exp] train_len2 2
|
||||
|
||||
# Train LLMs with HH-RLHF
|
||||
./train_hh.sh [id_of_exp] hh_train_len2 2
|
||||
|
||||
# Train LLMs with Summarize From Feedback
|
||||
./train_summarize.sh [id_of_exp] summarize_train_len2 2
|
||||
# Length 3
|
||||
./train3_summarize.sh [id_of_exp] summarize_train_len3_alpaca 3
|
||||
```
|
||||
You can modify the ```train.sh``` to train the model with different dataset.
|
||||
|
||||
The scripts can be easily modified to train LLMs with different datasets.
|
||||
|
||||
### Test
|
||||
You can run the following command to test the model:
|
||||
The following command can be used to test the model:
|
||||
```
|
||||
cd eval
|
||||
# Test LLMs with HH-RLHF
|
||||
cd eval_hh
|
||||
./run_infer_main_dist.sh
|
||||
|
||||
# Test LLMs with Summarize From Feedback
|
||||
cd ../eval_summarize
|
||||
./run_infer_main_dist.sh
|
||||
```
|
||||
> **Note:** Before run this script, you should modify the ```infer_main_dist.sh``` to specify ```id_of_exp``` and corresponding ranking length in training.
|
||||
> **Note:** Before running, the ```id_of_exp``` and corresponding ranking length (during training) in ```run_infer_main_dist.sh``` have to be specified.
|
||||
|
||||
## Citation
|
||||
If this work is helpful to you, welcome to cite our paper as:
|
||||
|
|
|
@ -76,7 +76,7 @@ if __name__ == "__main__":
|
|||
"helpful_online.json",
|
||||
"helpful_rejection.json"
|
||||
]:
|
||||
file_path = os.path.join("..", "data", "test", file_name)
|
||||
file_path = os.path.join("..", "data", "hh_test", file_name)
|
||||
with open(file_path, "r", encoding='utf-8') as f:
|
||||
infer_data = {line_index: json.loads(l) for line_index, l in enumerate(f.readlines()) if (line_index-rank) % rank_sum == 0}
|
||||
|
|
@ -62,7 +62,6 @@ def generate_pipeline(model, tokenizer, prompts, add_special_tokens=False, gen_k
|
|||
text = text_res[index]
|
||||
assert truncated_prompts[index].rstrip() in text
|
||||
text = text.replace(truncated_prompts[index].rstrip(), "").strip()
|
||||
# text = text[prompts_size[index]:].strip()
|
||||
for stop in ["Human:", "human:", "Assistant:", "assistant:"]:
|
||||
stop_ix = text.find(stop)
|
||||
if stop_ix >= 0:
|
|
@ -21,47 +21,7 @@ def get_bleu(hyp, ref):
|
|||
ref = ref.strip()
|
||||
return nltk.translate.bleu_score.sentence_bleu([ref], hyp)
|
||||
|
||||
# Thank trlx for their helpful code:
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/hh/ppo_hh.py#L115
|
||||
def create_reward_fn_1():
|
||||
reward_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
reward_tokenizer.pad_token = reward_tokenizer.eos_token
|
||||
reward_tokenizer.truncation_side = "left"
|
||||
reward_model = TrainRewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id)
|
||||
checkpoint = os.path.join("..", "rm", "gptj-rm-static", "hf_ckpt.pt")
|
||||
|
||||
reward_model.load_state_dict(torch.load(checkpoint))
|
||||
reward_device = "cuda:{}".format(rank)
|
||||
reward_model = reward_model.half().to(reward_device)
|
||||
reward_model.eval()
|
||||
|
||||
def get_score(prefixes, suffixes):
|
||||
# prefixes = [[p1, p1, p1], [p2, p2, p2]]
|
||||
# suffixes = [s1, s2]
|
||||
texts = []
|
||||
for p, s in zip(prefixes,suffixes):
|
||||
p = "".join(p)
|
||||
p = p.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>", "\n\nAssistant: ")
|
||||
texts.append(p + s + reward_tokenizer.eos_token)
|
||||
|
||||
input = reward_tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=reward_tokenizer.max_len_single_sentence,
|
||||
return_tensors="pt",
|
||||
).to(reward_device)
|
||||
|
||||
with torch.no_grad():
|
||||
rewards = reward_model(input['input_ids']) # [batch]
|
||||
|
||||
return rewards.view(-1)
|
||||
# return torch.sigmoid(rewards.view(-1))
|
||||
|
||||
return get_score, 16
|
||||
|
||||
def create_reward_fn_2():
|
||||
# model_name = "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5"
|
||||
model_name = "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1"
|
||||
model_device = "cuda:{}".format(rank)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
@ -3,19 +3,14 @@ export OMP_NUM_THREADS=16
|
|||
|
||||
id=$1
|
||||
ranking_len=$2
|
||||
# 30 min
|
||||
accelerate launch --config_file dp_config.yaml infer_and_eval_main_generate.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/generate_infer_main_${id}_${ranking_len}.log 2>&1
|
||||
|
||||
#10 min
|
||||
accelerate launch --config_file dp_config.yaml infer_and_eval_main_reward.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/reward_infer_main_${id}_${ranking_len}.log 2>&1
|
||||
|
||||
#1 second
|
||||
python -u infer_and_eval_main_score.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/score_infer_main_${id}_${ranking_len}.log 2>&1
|
||||
|
||||
# total 40 min
|
|
@ -0,0 +1,16 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
deepspeed_config: {}
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
dynamo_backend: 'NO'
|
||||
fsdp_config: {}
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config: {}
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
|
@ -0,0 +1,95 @@
|
|||
#import some packages and reward funcs
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import metrics2
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
LlamaTokenizer,
|
||||
AutoModelForCausalLM
|
||||
)
|
||||
from infer_func_now import setup_seed, generate_pipeline
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import InitProcessGroupKwargs
|
||||
from datetime import timedelta
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="")
|
||||
parser.add_argument('--index', type=str)
|
||||
parser.add_argument('--stage', type=int)
|
||||
parser.add_argument('--directory', default="best_checkpoint", type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
|
||||
accelerator = Accelerator(kwargs_handlers=[kwargs])# **accelerator_log_kwargs)
|
||||
rank = int(os.environ['RANK'])
|
||||
rank_sum = accelerator.num_processes
|
||||
model_name_or_path = os.path.join("..", "checkpoints", f"index_{args.index}", f"stage_{args.stage}", f"{args.directory}")
|
||||
model_device = "cuda:{}".format(rank)
|
||||
|
||||
model_config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=model_config, torch_dtype=torch.bfloat16).to(model_device)
|
||||
if accelerator.is_main_process:
|
||||
print(type(model))
|
||||
print(model.config)
|
||||
if model.config.architectures[0].lower() == "llamaforcausallm":
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer.unk_token = "<unk>"
|
||||
tokenizer.bos_token = "<s>"
|
||||
tokenizer.eos_token = "</s>"
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
|
||||
tokenizer.pad_token=tokenizer.eos_token,
|
||||
tokenizer.pad_token_id=tokenizer.eos_token_id,
|
||||
tokenizer.sep_token = "<sep>"
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
print(model.dtype)
|
||||
torch.cuda.empty_cache()
|
||||
model.eval()
|
||||
print(f"Rank {rank} is activated...")
|
||||
if accelerator.is_main_process:
|
||||
file_name = "test.json"
|
||||
save_path = os.path.join("inference_res/cache", "infer_generate_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
file_name = "test.json"
|
||||
file_path = os.path.join("..", "data", "summarize_test", file_name)
|
||||
with open(file_path, "r", encoding='utf-8') as f:
|
||||
infer_data = {line_index: json.loads(l) for line_index, l in enumerate(f.readlines()) if (line_index-rank) % rank_sum == 0}
|
||||
|
||||
for line_index in infer_data:
|
||||
infer_data[line_index]["line_index"] = line_index
|
||||
infer_data = [infer_data[line_index] for line_index in infer_data]
|
||||
|
||||
prompts = [l['prefix'][0] for l in infer_data]
|
||||
|
||||
setup_seed()
|
||||
generated_suffixes, truncated_prompts = generate_pipeline(model, tokenizer, prompts, add_special_tokens=True)
|
||||
setup_seed()
|
||||
save_path = os.path.join("inference_res/cache", "infer_generate_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
|
||||
for index in range(len(infer_data)):
|
||||
infer_data[index]['infer'] = {"t": generated_suffixes[index]}
|
||||
with open(save_path, 'a', encoding='utf-8') as f:
|
||||
for line in infer_data:
|
||||
content = json.dumps(line, ensure_ascii=False)
|
||||
f.write(content+'\n')
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
print("")
|
||||
if accelerator.is_main_process:
|
||||
print("Eval on {}".format(file_name))
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
|
@ -0,0 +1,84 @@
|
|||
#import some packages and reward funcs
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import metrics2
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
LlamaTokenizer,
|
||||
AutoModelForCausalLM
|
||||
)
|
||||
from peft import PeftConfig, PeftModel
|
||||
from infer_func_now import setup_seed
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import InitProcessGroupKwargs
|
||||
from datetime import timedelta
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="")
|
||||
parser.add_argument('--index', type=str)
|
||||
parser.add_argument('--stage', type=int)
|
||||
parser.add_argument('--directory', default="best_checkpoint", type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
setup_seed()
|
||||
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
|
||||
accelerator = Accelerator(kwargs_handlers=[kwargs])
|
||||
rank = int(os.environ['RANK'])
|
||||
rank_sum = accelerator.num_processes
|
||||
torch.cuda.empty_cache()
|
||||
print(f"Rank {rank} is activated...")
|
||||
if accelerator.is_main_process:
|
||||
file_name = "test.json"
|
||||
save_path = os.path.join("inference_res", "infer_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
if os.path.exists(save_path):
|
||||
os.remove(save_path)
|
||||
|
||||
save_path = os.path.join("inference_res/cache", "infer_generate_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
with open(save_path, 'r', encoding='utf-8') as f:
|
||||
infer_data = [json.loads(l) for l in f.readlines()]
|
||||
if "line_index" in infer_data[0]:
|
||||
infer_data = {l["line_index"]: l for l in infer_data}
|
||||
with open(save_path, 'w', encoding='utf-8') as f:
|
||||
infer_data = [infer_data[line_index] for line_index in range(len(infer_data))]
|
||||
for line in infer_data:
|
||||
content = json.dumps(line, ensure_ascii=False)
|
||||
f.write(content+'\n')
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
get_score, reward_batch_size = metrics2.create_reward_fn()
|
||||
|
||||
file_name = "test.json"
|
||||
save_path = os.path.join("inference_res/cache", "infer_generate_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
with open(save_path, 'r', encoding='utf-8') as f:
|
||||
infer_data = [json.loads(l) for line_index, l in enumerate(f.readlines()) if (line_index - rank) % rank_sum == 0]
|
||||
raw_prefixes = [l['prefix'][0].strip() + " " for l in infer_data]
|
||||
generated_suffixes = [l['infer']["t"].strip() for l in infer_data]
|
||||
|
||||
setup_seed()
|
||||
rewards = []
|
||||
batch_size = reward_batch_size
|
||||
for index in tqdm.tqdm(range(0,len(raw_prefixes), batch_size), desc=f"Rank {rank} rewarding..."):
|
||||
if len(raw_prefixes) - index < batch_size:
|
||||
batch_size = len(raw_prefixes) - index
|
||||
rewards.extend(torch.sigmoid(get_score(raw_prefixes[index:index+batch_size], generated_suffixes[index:index+batch_size])).cpu().detach().numpy().tolist())
|
||||
assert len(rewards) == len(generated_suffixes) and len(rewards) == len(infer_data), (len(rewards), len(generated_suffixes), len(infer_data))
|
||||
|
||||
for index in range(len(infer_data)):
|
||||
infer_data[index]["infer"]["score"] = rewards[index]
|
||||
infer_data[index]["infer"]["bleu"] = metrics2.get_bleu(infer_data[index]['infer']['t'], infer_data[index]['suffix'][0])
|
||||
|
||||
save_path = os.path.join("inference_res", "infer_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
with open(save_path, 'a', encoding='utf-8') as f:
|
||||
for line in infer_data:
|
||||
content = json.dumps(line, ensure_ascii=False)
|
||||
f.write(content+'\n')
|
||||
print(f"Rank {rank} completed!")
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import argparse
|
||||
import json
|
||||
import tqdm
|
||||
import evaluate
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="")
|
||||
parser.add_argument('--index', type=str)
|
||||
parser.add_argument('--stage', type=int)
|
||||
parser.add_argument('--directory', default="best_checkpoint", type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
file_name = "test.json"
|
||||
save_path = os.path.join("inference_res", "infer_main_{}_{}_{}".format(args.index, args.stage, file_name))
|
||||
with open(save_path, 'r', encoding='utf-8') as f:
|
||||
infer_data = [json.loads(l) for line_index, l in enumerate(f.readlines())]
|
||||
|
||||
bleu = 0
|
||||
avg_reward = 0
|
||||
predictions = []
|
||||
references = []
|
||||
|
||||
for line in infer_data:
|
||||
avg_reward += line['infer']['score']
|
||||
bleu += line['infer']['bleu']
|
||||
predictions.append(
|
||||
line['infer']["t"].strip()
|
||||
)
|
||||
references.append(
|
||||
line["suffix"][0].strip()
|
||||
)
|
||||
|
||||
rouge = evaluate.load('rouge')
|
||||
results = rouge.compute(predictions=predictions, references=references)
|
||||
bleu = bleu / len(infer_data)
|
||||
avg_reward = avg_reward / len(infer_data)
|
||||
|
||||
print("Eval on {}".format(file_name))
|
||||
print("BLEU: {}".format(bleu))
|
||||
print("Avg Reward: {}".format(avg_reward))
|
||||
for key in results:
|
||||
print("{}: {}".format(key, results[key]))
|
|
@ -0,0 +1,76 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def setup_seed(seed=42):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.benchmark=False
|
||||
torch.backends.cudnn.deterministic=True
|
||||
|
||||
def generate_pipeline(model, tokenizer, prompts, add_special_tokens=False, gen_kwarg={"max_new_tokens": 64, "num_beams": 1, "do_sample": False,}, batch_size = 28):
|
||||
def pipeline(prompts):
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.truncation_side = "right"
|
||||
|
||||
new_prompts = []
|
||||
for p in prompts:
|
||||
assert p[-7:] == "\nTL;DR:", p[-7:]
|
||||
p = p[:-7]
|
||||
new_prompts.append(p)
|
||||
|
||||
model_inputs = tokenizer(
|
||||
new_prompts,
|
||||
max_length=512,
|
||||
truncation=True,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
truncated_prompts = tokenizer.batch_decode(model_inputs['input_ids'], skip_special_tokens=True)
|
||||
truncated_prompts = [p + "\nTL;DR:" for p in truncated_prompts]
|
||||
model_inputs = tokenizer(
|
||||
truncated_prompts,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
truncated_prompts = tokenizer.batch_decode(model_inputs['input_ids'], skip_special_tokens=True)
|
||||
prompts_size = [len(s) for s in truncated_prompts]
|
||||
return model_inputs, prompts_size, truncated_prompts
|
||||
|
||||
model_inputs, prompts_size, truncated_prompts = pipeline(prompts)
|
||||
text_res = []
|
||||
for index in tqdm.tqdm(range(0, len(model_inputs["input_ids"]), batch_size)):
|
||||
if len(model_inputs["input_ids"]) - index < batch_size:
|
||||
batch_size = len(model_inputs["input_ids"]) - index
|
||||
|
||||
batch = {key: model_inputs[key][index:index+batch_size].to(model.device) for key in model_inputs}
|
||||
with torch.no_grad():
|
||||
ts = model.generate(
|
||||
**batch,
|
||||
**gen_kwarg,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
).cpu().detach()
|
||||
text_res.append(ts)
|
||||
|
||||
for index in range(len(text_res)):
|
||||
text_res[index] = tokenizer.batch_decode(
|
||||
text_res[index],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
text_res = sum(text_res, [])
|
||||
for index in range(len(text_res)):
|
||||
text = text_res[index]
|
||||
assert truncated_prompts[index].rstrip() in text
|
||||
text = text.replace(truncated_prompts[index].rstrip(), "").strip()
|
||||
for stop in ["\n\n"]:
|
||||
stop_ix = text.find(stop)
|
||||
if stop_ix >= 0:
|
||||
text = text[:stop_ix].rstrip()
|
||||
text_res[index] = text
|
||||
|
||||
return text_res, truncated_prompts
|
|
@ -0,0 +1,70 @@
|
|||
import sys
|
||||
sys.path.append("..")
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
||||
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXModel, GPTNeoXPreTrainedModel
|
||||
from transformers.utils import ModelOutput
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
import tqdm
|
||||
import nltk
|
||||
|
||||
rank = int(os.environ['RANK'])
|
||||
|
||||
def get_bleu(hyp, ref):
|
||||
hyp = hyp.strip()
|
||||
ref = ref.strip()
|
||||
return nltk.translate.bleu_score.sentence_bleu([ref], hyp)
|
||||
|
||||
def create_reward_fn_2():
|
||||
model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
|
||||
model_device = "cuda:{}".format(rank)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.truncation_side = "right"
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(model_device)
|
||||
reward_model.eval()
|
||||
|
||||
def get_score(prefixes, suffixes):
|
||||
input_content = tokenizer(
|
||||
prefixes,
|
||||
suffixes,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
return_tensors="pt",
|
||||
).to(model_device)
|
||||
with torch.no_grad():
|
||||
rewards = reward_model(**input_content).logits
|
||||
|
||||
return rewards.view(-1)
|
||||
|
||||
return get_score, 140
|
||||
|
||||
def create_reward_fn_3():
|
||||
model_name = "OpenAssistant/reward-model-deberta-v3-large"
|
||||
model_device = "cuda:{}".format(rank)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.truncation_side = "right"
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(model_device)
|
||||
reward_model.eval()
|
||||
|
||||
def get_score(prefixes, suffixes):
|
||||
input_content = tokenizer(
|
||||
prefixes,
|
||||
suffixes,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
return_tensors="pt",
|
||||
).to(model_device)
|
||||
with torch.no_grad():
|
||||
rewards = reward_model(**input_content).logits
|
||||
|
||||
return rewards.view(-1)
|
||||
|
||||
return get_score, 140
|
||||
|
||||
create_reward_fn = create_reward_fn_2
|
|
@ -0,0 +1,16 @@
|
|||
export PYTHONIOENCODING=utf-8
|
||||
export OMP_NUM_THREADS=16
|
||||
|
||||
id=$1
|
||||
ranking_len=$2
|
||||
accelerate launch --config_file dp_config.yaml infer_and_eval_main_generate.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/generate_infer_main_${id}_${ranking_len}.log 2>&1
|
||||
|
||||
accelerate launch --config_file dp_config.yaml infer_and_eval_main_reward.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/reward_infer_main_${id}_${ranking_len}.log 2>&1
|
||||
|
||||
python -u infer_and_eval_main_score.py \
|
||||
--index $id \
|
||||
--stage $ranking_len > logs/score_infer_main_${id}_${ranking_len}.log 2>&1
|
|
@ -8,3 +8,6 @@ torch==1.13.1+cu117
|
|||
tqdm==4.64.1
|
||||
transformers==4.28.1
|
||||
deepspeed==0.8.1
|
||||
evaluate
|
||||
rouge_score
|
||||
tensorboard
|
Before Width: | Height: | Size: 757 KiB After Width: | Height: | Size: 757 KiB |
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
|
@ -0,0 +1,22 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
dynamo_backend: 'NO'
|
||||
fsdp_config: {}
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config: {}
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
use_cpu: false
|
|
@ -31,8 +31,6 @@ def hhrlhf_preprocess(path,filename,index_generator,split='train'):
|
|||
|
||||
dialogue: list[tuple[str, str]] = []
|
||||
|
||||
# go over messages and combine consecutive messages from the
|
||||
# same speaker (OA v1 expects alternating roles)
|
||||
role = None
|
||||
messages = []
|
||||
for line in lines:
|
||||
|
@ -72,7 +70,7 @@ def hhrlhf_preprocess(path,filename,index_generator,split='train'):
|
|||
if len(rejected) == len(chosen):
|
||||
assert chosen[-1][0] == rejected[-1][0]
|
||||
|
||||
prefix = [role+text for role, text in chosen[:-1]] # need to be concated with [EOS] in practice
|
||||
prefix = [role+text for role, text in chosen[:-1]]
|
||||
prefix.append(chosen[-1][0])
|
||||
good_reply = chosen[-1][1] # last part of dialog, the text
|
||||
bad_reply = rejected[-1][1] # last part of dialog, the text
|
||||
|
@ -112,7 +110,7 @@ def hhrlhf_preprocess(path,filename,index_generator,split='train'):
|
|||
prefix.append([role, text])
|
||||
elif role == "<|assistant|>":
|
||||
temp_prefix = [temp_role+temp_text for temp_role, temp_text in prefix]
|
||||
temp_prefix.append(role) # need to be concated with [EOS] in practice
|
||||
temp_prefix.append(role)
|
||||
temp_reply = text # last part of dialog, the text
|
||||
chosen_sample = {
|
||||
'extended':[
|
||||
|
@ -148,7 +146,7 @@ def hhrlhf_preprocess(path,filename,index_generator,split='train'):
|
|||
prefix.append([role, text])
|
||||
elif role == "<|assistant|>":
|
||||
temp_prefix = [temp_role+temp_text for temp_role, temp_text in prefix]
|
||||
temp_prefix.append(role) # need to be concated with [EOS] in practice
|
||||
temp_prefix.append(role)
|
||||
temp_reply = text # last part of dialog, the text
|
||||
rejected_sample = {
|
||||
'extended':[
|
||||
|
@ -184,56 +182,13 @@ def hhrlhf_preprocess(path,filename,index_generator,split='train'):
|
|||
return samples
|
||||
|
||||
if __name__ == "__main__":
|
||||
# get a global index generator
|
||||
global_index_generator = gen_global_index()
|
||||
|
||||
# prepare to post-processing
|
||||
res = {
|
||||
'hhrlhf':[],
|
||||
}
|
||||
|
||||
prompts = {
|
||||
'hhrlhf': '<prefix>',
|
||||
'summarize':'<prefix>',
|
||||
'webgpt':'<prefix>',
|
||||
'tldr':'<prefix>',
|
||||
}
|
||||
# process raw datasets
|
||||
# hhrlhf
|
||||
res['hhrlhf'] = [
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','harmless-base'),'train.jsonl',global_index_generator,split='train'),
|
||||
# hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','harmless-base'),'test.jsonl',global_index_generator,split='test'),
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-base'),'train.jsonl',global_index_generator,split='train'),
|
||||
# hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-base'),'test.jsonl',global_index_generator,split='test'),
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-online'),'train.jsonl',global_index_generator,split='train'),
|
||||
# hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-online'),'test.jsonl',global_index_generator,split='test'),
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-rejection'),'train.jsonl',global_index_generator,split='train'),
|
||||
# hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-rejection'),'test.jsonl',global_index_generator,split='test'),
|
||||
]
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','harmless-base'),'test.jsonl',global_index_generator,split='dev')
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-base'),'test.jsonl',global_index_generator,split='dev')
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-online'),'test.jsonl',global_index_generator,split='dev')
|
||||
hhrlhf_preprocess(os.path.join('..','..','data','raw_data','hhrlhf','helpful-rejection'),'test.jsonl',global_index_generator,split='dev')
|
||||
|
||||
global_prefixes = []
|
||||
global_extended_samples = 0
|
||||
for key in res:
|
||||
for dataset in res[key]:
|
||||
for sample in dataset:
|
||||
for sub_sample in sample['extended']:
|
||||
prefix = "".join(sub_sample['prefix'])
|
||||
prefix = prefix.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>", "\n\nAssistant: ")
|
||||
prefix = prompts[key].replace('<prefix>', prefix)
|
||||
global_prefixes.append(
|
||||
{
|
||||
'id': sub_sample['id'],
|
||||
'prefix': prefix,
|
||||
'target_num': sub_sample['target_num'],
|
||||
'target': []
|
||||
}
|
||||
)
|
||||
global_extended_samples += sub_sample['target_num']
|
||||
|
||||
|
||||
print('Total Num: {}'.format(len(global_prefixes)))
|
||||
print('Total Extended Num: {}'.format(global_extended_samples))
|
|
@ -5,9 +5,8 @@ import json
|
|||
import random
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from utils.metrics import create_reward_fn_3
|
||||
get_score, reward_batch_size = create_reward_fn_3()
|
||||
# get_score = None
|
||||
from utils.metrics_hh import create_reward_fn
|
||||
get_score, reward_batch_size = create_reward_fn()
|
||||
|
||||
def split_trans(split):
|
||||
if split == 'train' or split == 'test' or split == 'dev':
|
||||
|
@ -48,7 +47,6 @@ def reward_model_ranker(prefixes, suffixes):
|
|||
|
||||
def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None, split='train'):
|
||||
file_list = []
|
||||
# for root,dirs,files in os.walk('refilled_data'):
|
||||
for root,dirs,files in os.walk(root_dir):
|
||||
for file in files:
|
||||
if not file.endswith("json"):
|
||||
|
@ -103,16 +101,19 @@ def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None,
|
|||
for l in training_data:
|
||||
l['reward'] = [1.0] * len(l['suffix'])
|
||||
|
||||
for l in training_data:
|
||||
l['sft_index'] = 0
|
||||
|
||||
return training_data
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.join('..','..','data',"preprocessed_data")
|
||||
data_aug = False
|
||||
os.makedirs(os.path.join('..','..','data','train_len2'), exist_ok=True)
|
||||
os.makedirs(os.path.join('..','..','data','hh_train_len2'), exist_ok=True)
|
||||
random.seed(42)
|
||||
training_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf"), if_score = True, if_rerank=True, split = 'train')
|
||||
random.shuffle(training_data)
|
||||
with open(os.path.join('..','..','data','train_len2','train.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_train_len2','train.json'),'w', encoding='utf-8') as f:
|
||||
for sample in training_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
|
||||
|
@ -125,7 +126,6 @@ if __name__ == '__main__':
|
|||
helpful_base_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "helpful-base"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(helpful_base_dev_data)
|
||||
|
||||
|
||||
random.seed(42)
|
||||
helpful_online_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "helpful-online"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(helpful_online_dev_data)
|
||||
|
@ -137,6 +137,6 @@ if __name__ == '__main__':
|
|||
total_dev_data = harmless_base_dev_data + helpful_base_dev_data + helpful_online_dev_data + helpful_rejection_dev_data
|
||||
random.shuffle(total_dev_data)
|
||||
total_dev_data = total_dev_data[:280]
|
||||
with open(os.path.join('..','..','data','dev','sampled_dev.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_dev','sampled_dev.json'),'w', encoding='utf-8') as f:
|
||||
for sample in total_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
|
@ -47,7 +47,6 @@ def reward_model_ranker(prefixes, suffixes):
|
|||
|
||||
def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None, split='train'):
|
||||
file_list = []
|
||||
# for root,dirs,files in os.walk('refilled_data'):
|
||||
for root,dirs,files in os.walk(root_dir):
|
||||
for file in files:
|
||||
if not file.endswith("json"):
|
||||
|
@ -103,37 +102,40 @@ def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None,
|
|||
for l in training_data:
|
||||
l['reward'] = [1.0] * len(l['suffix'])
|
||||
|
||||
for l in training_data:
|
||||
l['sft_index'] = 0
|
||||
|
||||
return training_data
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.join('..','..','data',"preprocessed_data")
|
||||
|
||||
data_aug = False
|
||||
os.makedirs(os.path.join('..','..','data','test'), exist_ok=True)
|
||||
os.makedirs(os.path.join('..','..','data','hh_test'), exist_ok=True)
|
||||
random.seed(42)
|
||||
harmless_base_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "harmless-base"), if_score = True, if_rerank=True, split = 'dev')
|
||||
random.shuffle(harmless_base_dev_data)
|
||||
with open(os.path.join('..','..','data','test','harmless_base.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_test','harmless_base.json'),'w', encoding='utf-8') as f:
|
||||
for sample in harmless_base_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
|
||||
random.seed(42)
|
||||
helpful_base_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "helpful-base"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(helpful_base_dev_data)
|
||||
with open(os.path.join('..','..','data','test','helpful_base.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_test','helpful_base.json'),'w', encoding='utf-8') as f:
|
||||
for sample in helpful_base_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
|
||||
random.seed(42)
|
||||
helpful_online_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "helpful-online"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(helpful_online_dev_data)
|
||||
with open(os.path.join('..','..','data','test','helpful_online.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_test','helpful_online.json'),'w', encoding='utf-8') as f:
|
||||
for sample in helpful_online_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
|
||||
random.seed(42)
|
||||
helpful_rejection_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "hhrlhf", "helpful-rejection"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(helpful_rejection_dev_data)
|
||||
with open(os.path.join('..','..','data','test','helpful_rejection.json'),'w', encoding='utf-8') as f:
|
||||
with open(os.path.join('..','..','data','hh_test','helpful_rejection.json'),'w', encoding='utf-8') as f:
|
||||
for sample in helpful_rejection_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
|
@ -0,0 +1,109 @@
|
|||
import json
|
||||
import re
|
||||
import pprint
|
||||
import os
|
||||
import tqdm
|
||||
import random
|
||||
random.seed(42)
|
||||
|
||||
def gen_global_index():
|
||||
index = 0
|
||||
while True:
|
||||
yield index
|
||||
index += 1
|
||||
|
||||
def split_trans(split):
|
||||
if split == 'train' or split == 'test' or split == 'dev':
|
||||
return split
|
||||
elif split == 'valid':
|
||||
return 'dev'
|
||||
elif split == 'valid1':
|
||||
return 'dev'
|
||||
elif split == 'valid2':
|
||||
return 'test'
|
||||
else:
|
||||
raise Exception('guaiguaidigai')
|
||||
|
||||
def summarize_from_feedback_preprocess(path,index_generator):
|
||||
files = os.listdir(path)
|
||||
files = [filename for filename in files if filename.endswith('.json')]
|
||||
target_samples = {
|
||||
'train':[],
|
||||
'dev':[],
|
||||
'test':[]
|
||||
}
|
||||
|
||||
for filename in files:
|
||||
with open(os.path.join(path,filename),'r', encoding="utf-8") as f:
|
||||
raw = f.readlines()
|
||||
|
||||
data = []
|
||||
for line in raw:
|
||||
line = json.loads(line)
|
||||
data.append(line)
|
||||
|
||||
samples = []
|
||||
bar = tqdm.tqdm(data)
|
||||
for index,sample in enumerate(bar):
|
||||
bar.set_description(os.path.join(path,filename))
|
||||
assert len(sample['summaries']) == 2
|
||||
if 'post' in sample['info']:
|
||||
prefix = "SUBREDDIT: r/{}\nTITLE: {}\nPOST: {}\nTL;DR:".format(sample['info']['subreddit'], sample['info']['title'],sample['info']['post']).strip()
|
||||
one_sample = {
|
||||
'available': [
|
||||
{
|
||||
'id':next(index_generator),
|
||||
'prefix': prefix,
|
||||
'target_num':2,
|
||||
'target':[
|
||||
" {}".format(sample['summaries'][sample['choice']]['text'].strip()),
|
||||
" {}".format(sample['summaries'][1-sample['choice']]['text'].strip()),
|
||||
]
|
||||
},
|
||||
],
|
||||
'split': split_trans(sample['split']),
|
||||
'source': {
|
||||
'path': os.path.join(path,filename),
|
||||
'line_num': index+1,
|
||||
}
|
||||
}
|
||||
target_samples[one_sample['split']].append(one_sample)
|
||||
else:
|
||||
prefix = "Article: {}\nTL;DR:".format(sample['info']['article'])
|
||||
pass
|
||||
|
||||
os.makedirs(path.replace('raw_data','preprocessed_data'), exist_ok=True)
|
||||
|
||||
true_dev_index = random.sample(list(range(len(target_samples['dev']))),1000)
|
||||
true_dev = []
|
||||
for index, sample in enumerate(target_samples['dev']):
|
||||
if index in true_dev_index:
|
||||
sample['split'] = 'dev'
|
||||
true_dev.append(sample)
|
||||
else:
|
||||
sample['split'] = 'train'
|
||||
target_samples['train'].append(sample)
|
||||
target_samples['dev'] = true_dev
|
||||
|
||||
with open(os.path.join(path.replace('raw_data','preprocessed_data'), "train.json"), 'w', encoding='utf-8') as f:
|
||||
for sample in target_samples['train']:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
print("{}: {}".format(os.path.join(path.replace('raw_data','preprocessed_data'),"train.json"),len(target_samples['train'])))
|
||||
|
||||
with open(os.path.join(path.replace('raw_data','preprocessed_data'), "dev.json"), 'w', encoding='utf-8') as f:
|
||||
for sample in target_samples['dev']:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
print("{}: {}".format(os.path.join(path.replace('raw_data','preprocessed_data'),"dev.json"),len(target_samples['dev'])))
|
||||
|
||||
with open(os.path.join(path.replace('raw_data','preprocessed_data'), "test.json"), 'w', encoding='utf-8') as f:
|
||||
for sample in target_samples['test']:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
print("{}: {}".format(os.path.join(path.replace('raw_data','preprocessed_data'),"test.json"),len(target_samples['test'])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
global_index_generator = gen_global_index()
|
||||
|
||||
summarize_from_feedback_preprocess(
|
||||
os.path.join('..','..','data','raw_data','summarize_from_feedback','comparisons'),
|
||||
global_index_generator
|
||||
)
|
|
@ -0,0 +1,118 @@
|
|||
import os
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from utils.metrics_summarize import create_reward_fn
|
||||
get_score, reward_batch_size = create_reward_fn()
|
||||
|
||||
def split_trans(split):
|
||||
if split == 'train' or split == 'test' or split == 'dev':
|
||||
return split
|
||||
elif split == 'valid':
|
||||
return 'dev'
|
||||
elif split == 'valid1':
|
||||
return 'dev'
|
||||
elif split == 'valid2':
|
||||
return 'test'
|
||||
else:
|
||||
raise Exception('guaiguaidigai')
|
||||
|
||||
def concat_wo_ranker(prefixes, suffixes):
|
||||
#prefixes = [[a,b,c],[d,e,f]]
|
||||
#suffixes = [[a,b,c],[d,e,f]]
|
||||
training_stage_num = len(prefixes[0])
|
||||
batch_size = len(prefixes)
|
||||
new_prefixes = sum(prefixes,[])
|
||||
new_suffixes = sum(suffixes,[])
|
||||
rewards = get_score(new_prefixes, new_suffixes).view(batch_size, training_stage_num).cpu().detach().numpy().tolist() #[batch_size, ranking]
|
||||
|
||||
return prefixes, suffixes, rewards
|
||||
|
||||
def reward_model_ranker(prefixes, suffixes):
|
||||
#prefixes = [[a,b,c],[d,e,f]]
|
||||
#suffixes = [[a,b,c],[d,e,f]]
|
||||
training_stage_num = len(prefixes[0])
|
||||
batch_size = len(prefixes)
|
||||
new_prefixes = sum(prefixes,[])
|
||||
new_suffixes = sum(suffixes,[])
|
||||
rewards = get_score(new_prefixes, new_suffixes).view(batch_size, training_stage_num).cpu().detach().numpy() #[batch_size, ranking]
|
||||
indices = np.argsort(-rewards,axis=1)
|
||||
prefixes = [[prefixes[i][index] for index in indices[i]] for i in range(batch_size)]
|
||||
suffixes = [[suffixes[i][index] for index in indices[i]] for i in range(batch_size)]
|
||||
rewards = [[float(rewards[i][index]) for index in indices[i]] for i in range(batch_size)]
|
||||
return prefixes, suffixes, rewards
|
||||
|
||||
def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None, split='train'):
|
||||
training_data = []
|
||||
with open(root_dir, 'r', encoding='utf-8') as f:
|
||||
raw_data = f.readlines()
|
||||
for line in raw_data:
|
||||
sample = json.loads(line)
|
||||
if split_trans(sample['split']) == split:
|
||||
new_sample = {'meta': sample['source'], 'prefix':[],'suffix':[]}
|
||||
if data_aug:
|
||||
for s in sample['extended']+sample['available']:
|
||||
for suffix in s['target']:
|
||||
assert isinstance(suffix,str)
|
||||
new_sample['prefix'].append(s['prefix'])
|
||||
new_sample['suffix'].append(suffix)
|
||||
else:
|
||||
for s in sample['available']:
|
||||
for suffix in s['target']:
|
||||
assert isinstance(suffix,str)
|
||||
new_sample['prefix'].append(s['prefix'])
|
||||
new_sample['suffix'].append(suffix)
|
||||
training_data.append(new_sample)
|
||||
if training_stage_num == None:
|
||||
training_stage_num = len(new_sample['prefix'])
|
||||
assert training_stage_num == len(new_sample['prefix'])
|
||||
|
||||
if if_score:
|
||||
batch_size = reward_batch_size / 2 # default
|
||||
for index in tqdm.tqdm(range(0,len(training_data),batch_size),desc="rewarding"):
|
||||
prefixes = []
|
||||
suffixes = []
|
||||
if len(training_data)-index < batch_size:
|
||||
batch_size = len(training_data)-index
|
||||
for sub_index in range(batch_size):
|
||||
prefixes.append(training_data[index+sub_index]['prefix'])
|
||||
suffixes.append(training_data[index+sub_index]['suffix'])
|
||||
if if_rerank:
|
||||
prefixes, suffixes, rewards = reward_model_ranker(prefixes,suffixes)
|
||||
else:
|
||||
prefixes, suffixes, rewards = concat_wo_ranker(prefixes,suffixes)
|
||||
for sub_index in range(batch_size):
|
||||
training_data[index+sub_index]['prefix'] = prefixes[sub_index]
|
||||
training_data[index+sub_index]['suffix'] = suffixes[sub_index]
|
||||
training_data[index+sub_index]['reward'] = rewards[sub_index]
|
||||
else:
|
||||
for l in training_data:
|
||||
l['reward'] = [1.0] * len(l['suffix'])
|
||||
|
||||
for l in training_data:
|
||||
l['sft_index'] = 0
|
||||
|
||||
return training_data
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.join('..','..','data',"preprocessed_data", "summarize_from_feedback", "comparisons")
|
||||
data_aug = False
|
||||
os.makedirs(os.path.join('..','..','data','summarize_train_len2'), exist_ok=True)
|
||||
random.seed(42)
|
||||
training_data = extract_train_data(root_dir = os.path.join(root_dir, "train.json"), if_score = True, if_rerank=True, split = 'train')
|
||||
random.shuffle(training_data)
|
||||
with open(os.path.join('..','..','data','summarize_train_len2','train.json'),'a', encoding='utf-8') as f:
|
||||
for sample in training_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
||||
|
||||
data_aug = False
|
||||
os.makedirs(os.path.join('..','..','data','summarize_dev'), exist_ok=True)
|
||||
random.seed(42)
|
||||
total_dev_data = extract_train_data(root_dir = os.path.join(root_dir, "dev.json"), if_score = True, if_rerank=False, split = 'dev')
|
||||
random.shuffle(total_dev_data)
|
||||
with open(os.path.join('..','..','data','summarize_dev','sampled_dev.json'),'a', encoding='utf-8') as f:
|
||||
for sample in total_dev_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
|
@ -0,0 +1,109 @@
|
|||
import os
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from utils.metrics import create_reward_fn_2
|
||||
get_score, reward_batch_size = create_reward_fn_2()
|
||||
|
||||
def split_trans(split):
|
||||
if split == 'train' or split == 'test' or split == 'dev':
|
||||
return split
|
||||
elif split == 'valid':
|
||||
return 'dev'
|
||||
elif split == 'valid1':
|
||||
return 'dev'
|
||||
elif split == 'valid2':
|
||||
return 'test'
|
||||
else:
|
||||
raise Exception('guaiguaidigai')
|
||||
|
||||
def concat_wo_ranker(prefixes, suffixes):
|
||||
#prefixes = [[a,b,c],[d,e,f]]
|
||||
#suffixes = [[a,b,c],[d,e,f]]
|
||||
training_stage_num = len(prefixes[0])
|
||||
batch_size = len(prefixes)
|
||||
new_prefixes = sum(prefixes,[])
|
||||
new_suffixes = sum(suffixes,[])
|
||||
rewards = get_score(new_prefixes, new_suffixes).view(batch_size, training_stage_num).cpu().detach().numpy().tolist() #[batch_size, ranking]
|
||||
|
||||
return prefixes, suffixes, rewards
|
||||
|
||||
def reward_model_ranker(prefixes, suffixes):
|
||||
#prefixes = [[a,b,c],[d,e,f]]
|
||||
#suffixes = [[a,b,c],[d,e,f]]
|
||||
training_stage_num = len(prefixes[0])
|
||||
batch_size = len(prefixes)
|
||||
new_prefixes = sum(prefixes,[])
|
||||
new_suffixes = sum(suffixes,[])
|
||||
rewards = get_score(new_prefixes, new_suffixes).view(batch_size, training_stage_num).cpu().detach().numpy() #[batch_size, ranking]
|
||||
indices = np.argsort(-rewards,axis=1)
|
||||
prefixes = [[prefixes[i][index] for index in indices[i]] for i in range(batch_size)]
|
||||
suffixes = [[suffixes[i][index] for index in indices[i]] for i in range(batch_size)]
|
||||
rewards = [[float(rewards[i][index]) for index in indices[i]] for i in range(batch_size)]
|
||||
return prefixes, suffixes, rewards
|
||||
|
||||
def extract_train_data(root_dir, if_score, if_rerank, training_stage_num = None, split='train'):
|
||||
training_data = []
|
||||
with open(root_dir, 'r', encoding='utf-8') as f:
|
||||
raw_data = f.readlines()
|
||||
for line in raw_data:
|
||||
sample = json.loads(line)
|
||||
if split_trans(sample['split']) == split:
|
||||
new_sample = {'meta': sample['source'], 'prefix':[],'suffix':[]}
|
||||
if data_aug:
|
||||
for s in sample['extended']+sample['available']:
|
||||
for suffix in s['target']:
|
||||
assert isinstance(suffix,str)
|
||||
new_sample['prefix'].append(s['prefix'])
|
||||
new_sample['suffix'].append(suffix)
|
||||
else:
|
||||
for s in sample['available']:
|
||||
for suffix in s['target']:
|
||||
assert isinstance(suffix,str)
|
||||
new_sample['prefix'].append(s['prefix'])
|
||||
new_sample['suffix'].append(suffix)
|
||||
training_data.append(new_sample)
|
||||
if training_stage_num == None:
|
||||
training_stage_num = len(new_sample['prefix'])
|
||||
assert training_stage_num == len(new_sample['prefix'])
|
||||
|
||||
if if_score:
|
||||
batch_size = reward_batch_size / 2 # default
|
||||
for index in tqdm.tqdm(range(0,len(training_data),batch_size),desc="rewarding"):
|
||||
prefixes = []
|
||||
suffixes = []
|
||||
if len(training_data)-index < batch_size:
|
||||
batch_size = len(training_data)-index
|
||||
for sub_index in range(batch_size):
|
||||
prefixes.append(training_data[index+sub_index]['prefix'])
|
||||
suffixes.append(training_data[index+sub_index]['suffix'])
|
||||
if if_rerank:
|
||||
prefixes, suffixes, rewards = reward_model_ranker(prefixes,suffixes)
|
||||
else:
|
||||
prefixes, suffixes, rewards = concat_wo_ranker(prefixes,suffixes)
|
||||
for sub_index in range(batch_size):
|
||||
training_data[index+sub_index]['prefix'] = prefixes[sub_index]
|
||||
training_data[index+sub_index]['suffix'] = suffixes[sub_index]
|
||||
training_data[index+sub_index]['reward'] = rewards[sub_index]
|
||||
else:
|
||||
for l in training_data:
|
||||
l['reward'] = [1.0] * len(l['suffix'])
|
||||
|
||||
for l in training_data:
|
||||
l['sft_index'] = 0
|
||||
|
||||
return training_data
|
||||
|
||||
if __name__ == '__main__':
|
||||
root_dir = os.path.join('..','..','data',"preprocessed_data", "summarize_from_feedback", "comparisons")
|
||||
data_aug = False
|
||||
os.makedirs(os.path.join('..','..','data','summarize_test'), exist_ok=True)
|
||||
random.seed(42)
|
||||
test_data = extract_train_data(root_dir = os.path.join(root_dir, "test.json"), if_score = True, if_rerank=False, split = 'test')
|
||||
random.shuffle(test_data)
|
||||
with open(os.path.join('..','..','data','summarize_test','test.json'),'w', encoding='utf-8') as f:
|
||||
for sample in test_data:
|
||||
f.write(json.dumps(sample,ensure_ascii=False)+'\n')
|
|
@ -0,0 +1,28 @@
|
|||
export OMP_NUM_THREADS=16
|
||||
root_dir=..
|
||||
|
||||
#stage 23
|
||||
id=$1
|
||||
data_path=$2
|
||||
ranking_len=$3
|
||||
mkdir -p $root_dir/logs/$id/$ranking_len
|
||||
accelerate launch --num_processes 7 --config_file ds_config2.yaml main.py \
|
||||
--task summarize \
|
||||
--train_file_path $root_dir/data/${data_path} \
|
||||
--validation_file_path $root_dir/data/summarize_dev \
|
||||
--validation_file_name sampled_dev.json \
|
||||
--output_dir $root_dir/checkpoints/index_$id/stage_$ranking_len \
|
||||
--log_path $root_dir/logs/$id/$ranking_len \
|
||||
--index $id \
|
||||
--seed 42 \
|
||||
--temperature 1 \
|
||||
--sft_weight 0.05 \
|
||||
--num_train_epochs 2 \
|
||||
--training_stage_num $ranking_len \
|
||||
--block_size 720 \
|
||||
--learning_rate 5e-6 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 28 \
|
||||
--model_name_or_path decapoda-research/llama-7b-hf \
|
||||
--do_train \
|
||||
--do_validation > $root_dir/logs/$id/$ranking_len/train_detail.log 2>&1
|
|
@ -7,8 +7,9 @@ data_path=$2
|
|||
ranking_len=$3
|
||||
mkdir -p $root_dir/logs/$id/$ranking_len
|
||||
accelerate launch --num_processes 7 --config_file ds_config.yaml main.py \
|
||||
--task hh \
|
||||
--train_file_path $root_dir/data/${data_path} \
|
||||
--validation_file_path $root_dir/data/dev \
|
||||
--validation_file_path $root_dir/data/hh_dev \
|
||||
--validation_file_name sampled_dev.json \
|
||||
--output_dir $root_dir/checkpoints/index_$id/stage_$ranking_len \
|
||||
--log_path $root_dir/logs/$id/$ranking_len \
|
|
@ -0,0 +1,28 @@
|
|||
export OMP_NUM_THREADS=16
|
||||
root_dir=..
|
||||
|
||||
#stage 23
|
||||
id=$1
|
||||
data_path=$2
|
||||
ranking_len=$3
|
||||
mkdir -p $root_dir/logs/$id/$ranking_len
|
||||
accelerate launch --num_processes 7 --config_file ds_config.yaml main.py \
|
||||
--task summarize \
|
||||
--train_file_path $root_dir/data/${data_path} \
|
||||
--validation_file_path $root_dir/data/summarize_dev \
|
||||
--validation_file_name sampled_dev.json \
|
||||
--output_dir $root_dir/checkpoints/index_$id/stage_$ranking_len \
|
||||
--log_path $root_dir/logs/$id/$ranking_len \
|
||||
--index $id \
|
||||
--seed 42 \
|
||||
--temperature 1 \
|
||||
--sft_weight 0.05 \
|
||||
--num_train_epochs 2 \
|
||||
--training_stage_num $ranking_len \
|
||||
--block_size 720 \
|
||||
--learning_rate 5e-6 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--per_device_eval_batch_size 28 \
|
||||
--model_name_or_path decapoda-research/llama-7b-hf \
|
||||
--do_train \
|
||||
--do_validation > $root_dir/logs/$id/$ranking_len/train_detail.log 2>&1
|
|
@ -6,6 +6,11 @@ from transformers import SchedulerType
|
|||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Preference Ranking Optimization For Human Alignment")
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="hh",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_train",
|
||||
action="store_true",
|
||||
|
@ -73,7 +78,7 @@ def parse_args():
|
|||
type=int,
|
||||
default=20,
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
|
|
|
@ -16,7 +16,7 @@ from transformers import (
|
|||
DataCollatorWithPadding,
|
||||
)
|
||||
|
||||
class DataManager():
|
||||
class HH_DataManager():
|
||||
def __init__(self, config, training_stage, tokenizer_path = args.model_name_or_path):
|
||||
self.config = config
|
||||
if self.config.architectures[0].lower() == "llamaforcausallm":
|
||||
|
@ -49,7 +49,7 @@ class DataManager():
|
|||
|
||||
def train_data_collator(self, features):
|
||||
samples_num = len(features)
|
||||
training_stage = self.training_stage #len(features[0]['input_ids'])
|
||||
training_stage = self.training_stage
|
||||
origin_state = (self.tokenizer.padding_side, self.tokenizer.truncation_side)
|
||||
|
||||
self.tokenizer.truncation_side = "left"
|
||||
|
@ -83,7 +83,6 @@ class DataManager():
|
|||
)['input_ids']
|
||||
ps_lens = [len(p_input_ids)-1 for p_input_ids in ps_input_ids]
|
||||
|
||||
# assemble
|
||||
self.tokenizer.padding_side = "right"
|
||||
self.tokenizer.truncation_side = "right"
|
||||
|
||||
|
@ -100,7 +99,6 @@ class DataManager():
|
|||
return_tensors = self.return_tensors,
|
||||
)
|
||||
|
||||
#prepare prefix_mask
|
||||
seq_len = batch["attention_mask"].shape[1]
|
||||
prefix_mask = []
|
||||
for p_len in ps_lens:
|
||||
|
@ -168,7 +166,6 @@ class DataManager():
|
|||
max_length = self.max_length - 128,
|
||||
truncation = True,
|
||||
add_special_tokens = self.add_special_tokens,
|
||||
# pad_to_multiple_of = self.pad_to_multiple_of,
|
||||
return_tensors = self.return_tensors,
|
||||
).to(model.device)
|
||||
batch_size = len(prefixes)
|
||||
|
@ -182,7 +179,170 @@ class DataManager():
|
|||
num_beams=1,
|
||||
do_sample=False,
|
||||
num_return_sequences = 1,
|
||||
) #tensor
|
||||
)
|
||||
|
||||
instant_text = self.batch_decode(predicted_sents)
|
||||
|
||||
# restore states
|
||||
self.tokenizer.padding_side, self.tokenizer.truncation_side = origin_state
|
||||
|
||||
for index in range(len(instant_text)):
|
||||
assert truncated_prefixes[index].rstrip() in instant_text[index], (truncated_prefixes[index].strip(), instant_text[index])
|
||||
instant_text[index] = instant_text[index].replace(truncated_prefixes[index].rstrip(), "").strip()
|
||||
instant_text[index] = self.early_truncation(instant_text[index])
|
||||
|
||||
return instant_text
|
||||
|
||||
class Summarize_DataManager():
|
||||
def __init__(self, config, training_stage, tokenizer_path = args.model_name_or_path):
|
||||
self.config = config
|
||||
if self.config.architectures[0].lower() == "llamaforcausallm":
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, use_fast=False)
|
||||
self.tokenizer.unk_token = "<unk>"
|
||||
self.tokenizer.bos_token = "<s>"
|
||||
self.tokenizer.eos_token = "</s>"
|
||||
else:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)
|
||||
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.padding = True
|
||||
self.max_length = args.block_size
|
||||
self.pad_to_multiple_of = 8
|
||||
self.return_tensors = "pt"
|
||||
self.add_special_tokens = True
|
||||
self.training_stage = training_stage
|
||||
self.stop_sequences = ["\n\n"]
|
||||
|
||||
def batch_decode(self, model_output):
|
||||
# model_output = [batch, seq_len]
|
||||
return self.tokenizer.batch_decode(model_output, skip_special_tokens=True)
|
||||
|
||||
def early_truncation(self, text):
|
||||
for stop in self.stop_sequences:
|
||||
stop_ix = text.find(stop)
|
||||
if stop_ix >= 0:
|
||||
text = text[:stop_ix].strip()
|
||||
return text.strip()
|
||||
|
||||
def train_data_collator(self, features):
|
||||
samples_num = len(features)
|
||||
training_stage = self.training_stage
|
||||
origin_state = (self.tokenizer.padding_side, self.tokenizer.truncation_side)
|
||||
|
||||
self.tokenizer.truncation_side = "right"
|
||||
ps = []
|
||||
ss = []
|
||||
rs = []
|
||||
sft_index = []
|
||||
for feature_index, feature in enumerate(features):
|
||||
for p, s, r in zip(feature['prefix'][:training_stage], feature['suffix'][:training_stage], feature['reward'][:training_stage]):
|
||||
ps.append(p)
|
||||
ss.append(s)
|
||||
rs.append(r)
|
||||
assert feature["sft_index"] < training_stage
|
||||
sft_index.append(feature["sft_index"])
|
||||
|
||||
ps_input_ids = self.tokenizer(
|
||||
ps,
|
||||
add_special_tokens = self.add_special_tokens,
|
||||
)['input_ids']
|
||||
ps_lens = [len(p_input_ids)-1 for p_input_ids in ps_input_ids]
|
||||
|
||||
self.tokenizer.padding_side = "right"
|
||||
self.tokenizer.truncation_side = "right"
|
||||
|
||||
texts = []
|
||||
for p, s in zip(ps, ss):
|
||||
texts.append(p + s)
|
||||
|
||||
batch = self.tokenizer(
|
||||
texts,
|
||||
padding=self.padding,
|
||||
max_length = self.max_length,
|
||||
truncation = True,
|
||||
add_special_tokens = self.add_special_tokens,
|
||||
return_tensors = self.return_tensors,
|
||||
)
|
||||
|
||||
seq_len = batch["attention_mask"].shape[1]
|
||||
prefix_mask = []
|
||||
for p_len in ps_lens:
|
||||
assert seq_len > p_len
|
||||
prefix_mask.append(
|
||||
[1 if i<p_len else 0 for i in range(seq_len)]
|
||||
)
|
||||
batch["prefix_mask"] = torch.tensor(prefix_mask)
|
||||
|
||||
batch['labels'] = batch["input_ids"].clone().detach()
|
||||
for key in batch:
|
||||
batch[key] = batch[key].view(samples_num,training_stage,-1)
|
||||
|
||||
batch['rewards'] = torch.tensor(rs).view(samples_num, -1)
|
||||
batch['sft_index'] = torch.tensor(sft_index) # [batch]
|
||||
# restore states
|
||||
self.tokenizer.padding_side, self.tokenizer.truncation_side = origin_state
|
||||
|
||||
return batch
|
||||
|
||||
def load_train_data(
|
||||
self,
|
||||
data_collator,
|
||||
data_file_path,
|
||||
data_file_name=None,
|
||||
extension='json',
|
||||
stream = None,
|
||||
):
|
||||
raw_datasets = load_dataset(extension, data_dir = data_file_path, data_files = data_file_name, streaming=True if stream != None else False, split="train")
|
||||
|
||||
dataloader = DataLoader(
|
||||
raw_datasets,
|
||||
shuffle=True,
|
||||
collate_fn=data_collator,
|
||||
batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
def infer_generate(self, model, prefixes):
|
||||
# prefixes = [prefix, prefix]
|
||||
origin_state = (self.tokenizer.padding_side, self.tokenizer.truncation_side)
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.tokenizer.truncation_side = "right"
|
||||
|
||||
new_prefixes = []
|
||||
for p in prefixes:
|
||||
assert p[-7:] == "\nTL;DR:", p[-7:]
|
||||
p = p[:-7]
|
||||
new_prefixes.append(p)
|
||||
|
||||
new_prefixes = self.batch_decode(
|
||||
self.tokenizer(
|
||||
new_prefixes,
|
||||
max_length = 512,
|
||||
truncation = True,
|
||||
add_special_tokens = self.add_special_tokens,
|
||||
)["input_ids"]
|
||||
)
|
||||
prefixes = [p + "\nTL;DR:" for p in new_prefixes]
|
||||
|
||||
batch = self.tokenizer(
|
||||
prefixes,
|
||||
padding=self.padding,
|
||||
add_special_tokens = self.add_special_tokens,
|
||||
return_tensors = self.return_tensors,
|
||||
).to(model.device)
|
||||
batch_size = len(prefixes)
|
||||
truncated_prefixes = self.batch_decode(batch['input_ids'])
|
||||
|
||||
with torch.no_grad():
|
||||
predicted_sents = model.generate(
|
||||
**batch,
|
||||
max_new_tokens = 64,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
num_return_sequences = 1,
|
||||
)
|
||||
|
||||
instant_text = self.batch_decode(predicted_sents)
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
import sys
|
||||
sys.path.append("..")
|
||||
import os
|
||||
os.environ["TRANSFORMERS_CACHE"] = os.path.join("..","..","transformers_cache","models")
|
||||
os.environ["HF_HOME"] = os.path.join("..","..","transformers_cache","datasets")
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from dataclasses import dataclass
|
||||
import nltk
|
||||
|
||||
def get_bleu(hyp, ref):
|
||||
hyp = hyp.strip()
|
||||
ref = ref.strip()
|
||||
return nltk.translate.bleu_score.sentence_bleu([ref], hyp)
|
||||
|
||||
def create_reward_fn_2():
|
||||
model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
|
||||
model_device = "cuda:{}".format(torch.cuda.device_count() - 1)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.truncation_side = "right"
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(model_device)
|
||||
reward_model.eval()
|
||||
|
||||
def get_score(prefixes, suffixes):
|
||||
input_content = tokenizer(
|
||||
prefixes,
|
||||
suffixes,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
return_tensors="pt",
|
||||
).to(model_device)
|
||||
with torch.no_grad():
|
||||
rewards = reward_model(**input_content).logits
|
||||
|
||||
return rewards.view(-1)
|
||||
|
||||
return get_score, 140
|
||||
|
||||
def create_reward_fn_3():
|
||||
model_name = "OpenAssistant/reward-model-deberta-v3-large"
|
||||
model_device = "cuda:{}".format(torch.cuda.device_count() - 1)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.truncation_side = "right"
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name).to(model_device)
|
||||
reward_model.eval()
|
||||
|
||||
def get_score(prefixes, suffixes):
|
||||
input_content = tokenizer(
|
||||
prefixes,
|
||||
suffixes,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=1024,
|
||||
return_tensors="pt",
|
||||
).to(model_device)
|
||||
with torch.no_grad():
|
||||
rewards = reward_model(**input_content).logits
|
||||
|
||||
return rewards.view(-1)
|
||||
|
||||
return get_score, 140
|
||||
|
||||
create_reward_fn = create_reward_fn_3
|
|
@ -7,7 +7,7 @@ from tqdm import tqdm
|
|||
import numpy as np
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from .data_manager import DataManager
|
||||
from .data_manager import HH_DataManager, Summarize_DataManager
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn.functional as F
|
||||
from transformers import (
|
||||
|
@ -16,7 +16,12 @@ from transformers import (
|
|||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from utils.metrics import create_reward_fn
|
||||
if args.task == "hh":
|
||||
from utils.metrics_hh import create_reward_fn
|
||||
elif args.task == "summarize":
|
||||
from utils.metrics_summarize import create_reward_fn
|
||||
else:
|
||||
raise ValueError("Invalid task name!")
|
||||
|
||||
class ProcessManager():
|
||||
def __init__(
|
||||
|
@ -31,10 +36,18 @@ class ProcessManager():
|
|||
self.model_config = AutoConfig.from_pretrained(self.model_path)
|
||||
|
||||
# set datamanager
|
||||
self.data_manager = DataManager(
|
||||
if args.task == "hh":
|
||||
self.data_manager = HH_DataManager(
|
||||
self.model_config,
|
||||
args.training_stage_num,
|
||||
)
|
||||
elif args.task == "summarize":
|
||||
self.data_manager = Summarize_DataManager(
|
||||
self.model_config,
|
||||
args.training_stage_num,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid task name!")
|
||||
|
||||
# set model
|
||||
self.model = AutoModelForCausalLM.from_pretrained(self.model_path,config=self.model_config)
|
||||
|
@ -321,20 +334,10 @@ class ProcessManager():
|
|||
with open(os.path.join(infer_file_path, infer_file_name), "r", encoding='utf-8') as f:
|
||||
infer_data = [json.loads(l) for l in f.readlines()]
|
||||
|
||||
is_dev = True
|
||||
if isinstance(infer_data[0]['prefix'][0],str):
|
||||
is_dev = True
|
||||
else:
|
||||
is_dev = False
|
||||
|
||||
# sort
|
||||
length = []
|
||||
for l in infer_data:
|
||||
lens = 0
|
||||
if is_dev:
|
||||
for p in l['prefix']:
|
||||
lens += (len(p.split(" ")))
|
||||
else:
|
||||
for p in l['prefix'][0]:
|
||||
lens += (len(p.split(" ")))
|
||||
length.append(lens)
|
||||
|
@ -349,11 +352,7 @@ class ProcessManager():
|
|||
if len(infer_data)-sample_index < infer_batch_size:
|
||||
infer_batch_size = len(infer_data)-sample_index
|
||||
|
||||
if is_dev:
|
||||
prefixes = [l['prefix'] for l in infer_data[sample_index:sample_index+infer_batch_size]]
|
||||
else:
|
||||
prefixes = [l['prefix'][0] for l in infer_data[sample_index:sample_index+infer_batch_size]]
|
||||
|
||||
suffixes = self.data_manager.infer_generate(model, prefixes)
|
||||
for l, s in zip(infer_data[sample_index:sample_index+infer_batch_size], suffixes):
|
||||
l['infer'] = {"t": s}
|
||||
|
|
Loading…
Reference in New Issue