add: MaskedThought

言斯 2024-04-24 15:06:47 +08:00
parent b9a2b64479
commit f9e562d055
89 changed files with 22157 additions and 0 deletions

MaskedThought/MAmmoTH/.gitignore
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

@@ -0,0 +1,281 @@
# **MAmmoTH** 🦣
This repo contains the code, data, and models for "[MAmmoTH: Building Math Generalist Models through Hybrid Instruction Tuning](https://arxiv.org/pdf/2309.05653.pdf)"
<div align="center">
🔥 🔥 🔥 Check out our <a href = "https://tiger-ai-lab.github.io/MAmmoTH/">[Project Page]</a> for more results and analysis!
</div>
<br>
<div align="center">
<img src="mammoth_github.png" width="80%" title="Introduction Figure">
</div>
### Datasets and Models
Our dataset and models are all available at Huggingface.
🤗 [MathInstruct Dataset](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
| | Base Model: Llama-2 | Base Model: Code Llama | Base Model: Mistral |
|----- |--------------------------------------------------------------- |--------------------------------------------------------------------------- |---------------------|
| 7B | 🦣 [MAmmoTH-7B](https://huggingface.co/TIGER-Lab/MAmmoTH-7B) | 🦣 [MAmmoTH-Coder-7B](https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-7B) | 🦣 [MAmmoTH-7B-Mistral](https://huggingface.co/TIGER-Lab/MAmmoTH-7B-Mistral) |
| 13B | 🦣 [MAmmoTH-13B](https://huggingface.co/TIGER-Lab/MAmmoTH-13B) | 🦣 [MAmmoTH-Coder-13B](https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-13B) | |
| 34B | - | 🦣 [MAmmoTH-Coder-34B](https://huggingface.co/TIGER-Lab/MAmmoTH-Coder-34B) | |
| 70B | 🦣 [MAmmoTH-70B](https://huggingface.co/TIGER-Lab/MAmmoTH-70B) | - | |
## **What's New?**
- [Dec. 4] We add the training and evaluation of MAmmoTH-7B-Mistral, which improves significantly over the LLaMA-2 version. We also have better support for vLLM.
- [Oct. 10] We update our decoding method to hybrid decoding: we first try PoT to generate a program; if it is not executable, we regenerate a CoT solution as the final answer. This hybrid decoding improves performance significantly. See the updated paper appendix for more details.
## Highlights
We demonstrate the results of our small MAmmoTH-7B-Mistral as follows:
| **Model** | **Decoding** | **GSM** | **MATH** | **MMLU-Math** |
|---------------------------|---------------|-----------|-----------|-----------|
| MAmmoTH-7B | **Hybrid** | 53.6 | 31.5 | 44.5 |
| MAmmoTH-Coder-7B | **Hybrid** | 59.4 | 33.4 | 47.2 |
| MetaMath-7B-Mistral | **CoT** | 77.7 | 28.2 | 49.3 |
| OpenChat-3.5-7B | **CoT** | 77.3 | 28.6 | 49.6 |
| ChatGLM-3-6B | **CoT** | 72.3 | 25.7 | 45.6 |
| DeepSeek-Coder-34B | **PoT** | 58.2 | 35.3 | 46.5 |
| Grok-1 | **CoT** | 62.9 | 15.7 | - |
| QWen-72B | **CoT** | 78.9 | 35.2 | - |
| DeepSeek-67B-Chat | **CoT** | **84.1** | 32.6 | - |
| MAmmoTH-7B-Mistral | **Hybrid** | 75.0 | **40.0** | **52.5** |
## **Table of Contents**
- [📌 Introduction](#introduction)
- [⚙️ Installation](#installation)
- [🛠️ Training and Inference](#training-and-inference)
- [📜 License](#license)
- [📖 Citation](#citation)
## **Introduction**
We introduce MAmmoTH 🦣, a series of open-source large language models (LLMs) specifically tailored for general math problem-solving. The MAmmoTH models are trained on MathInstruct, a meticulously curated instruction tuning dataset that is lightweight yet generalizable. MathInstruct is compiled from 13 math rationale datasets, six of which are newly curated by this work. It uniquely focuses on the hybrid use of chain-of-thought (CoT) and program-of-thought (PoT) rationales, and ensures extensive coverage of diverse mathematical fields.
## **Installation**
Clone this repository and install the required packages:
```bash
git clone https://github.com/TIGER-AI-Lab/MAmmoTH.git
cd MAmmoTH
pip install -r requirements.txt
```
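To quickly confirm that the environment is set up, you can run a minimal sanity check (this snippet is illustrative and not part of the repository's scripts):
```python
import torch
import transformers

# Print library versions and whether a CUDA device is visible.
print(torch.__version__, transformers.__version__, torch.cuda.is_available())
```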
## **Training and Inference**
### **Data Loading**
Run the following code to load the data:
```python
from datasets import load_dataset
dataset = load_dataset("TIGER-Lab/MathInstruct")
```
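To sanity-check the download, you can inspect the splits and one raw record with the standard `datasets` API (this sketch makes no assumption about the dataset's field names):
```python
from datasets import load_dataset

dataset = load_dataset("TIGER-Lab/MathInstruct")
print(dataset)                          # available splits and row counts
print(dataset["train"].column_names)    # fields provided by the dataset
print(dataset["train"][0])              # one raw training example
```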
### **Quick Start**
To play with our model, run:
```python
from transformers import pipeline

generator = pipeline("text-generation", "TIGER-Lab/MAmmoTH-Coder-7B")

alpaca_template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\n{query}\n\n### Response:"

query = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

### By default, MAmmoTH outputs a Chain-of-Thought (CoT) rationale
rationale_prefix = ""
### To get a Program-of-Thought (PoT) rationale instead, uncomment the line below
# rationale_prefix = " Let's write a program."

prompt = alpaca_template.format(query=query + rationale_prefix)
output = generator(prompt)[0]['generated_text']
print(output)
```
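The pipeline uses the model's default generation settings; if the rationale gets cut off, you can pass generation arguments explicitly (the values below are illustrative, not tuned):
```python
# Longer generation budget, greedy decoding, and return only the completion.
output = generator(prompt, max_new_tokens=512, do_sample=False, return_full_text=False)[0]['generated_text']
print(output)
```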
### **Large-scale Evaluation**
To replicate the experimental results in our paper, run:
```bash
### For open-ended questions, the dataset should be one of
### ['gsm8k', 'svamp', 'math', 'numglue', 'deepmind', 'simuleq'].
### We first try PoT; if the generated program is not executable, we fall back to CoT.
dataset='math'
python run_open.py \
--model "TIGER-Lab/MAmmoTH-7B-Mistral" \
--shots 0 \
--stem_flan_type "pot_prompt" \
--batch_size 8 \
--dataset $dataset \
--model_max_length 1500 \
--cot_backup \
--print \
--use_vllm
```
To run self-consistency over PoT/CoT with 10 sampled generations per question, run:
```bash
### For open-ended questions, the dataset should be one of
### ['gsm8k', 'svamp', 'math', 'numglue', 'deepmind', 'simuleq'].
### We first try PoT; if the generated program is not executable, we fall back to CoT.
dataset='gsm8k'
python run_open_sc.py \
--model "TIGER-Lab/MAmmoTH-7B-Mistral" \
--shots 0 \
--stem_flan_type "pot_prompt" \
--batch_size 8 \
--dataset $dataset \
--model_max_length 1500 \
--num_samples 10 \
--print
```
```bash
### For multiple-choice questions, the dataset should be one of
### ['aqua', 'sat', 'mmlu_mathematics'].
### We first try PoT; if the generated program is not executable, we fall back to CoT.
dataset='aqua'
python run_choice.py \
--model "TIGER-Lab/MAmmoTH-7B-Mistral" \
--shots 0 \
--stem_flan_type "pot_prompt" \
--batch_size 8 \
--dataset $dataset \
--cot_backup \
--print
```
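The `--cot_backup` flag enables the hybrid decoding described above. Conceptually it boils down to the sketch below; this is illustrative only (the real logic lives in `run_open.py`/`run_choice.py` via `utils.execute_with_timeout`), and `hybrid_decode`/`generate` are hypothetical names:
```python
import contextlib
import io

def hybrid_decode(question, generate):
    """Try a PoT program first; fall back to a CoT answer if it does not run."""
    program = generate(question + " Let's write a program.")
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(program, {})            # run the generated program
        printed = buffer.getvalue().strip()
        if printed:                      # executable and produced an answer
            return printed
    except Exception:
        pass                             # broken program: regenerate with CoT
    return generate(question)
```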
### **Fine-tuning**
To train the 7B/13B model, run:
```bash
torchrun --nproc_per_node [$WORKER_GPU] \
--master_addr [$WORKER_0_HOST] \
--node_rank [$ROLE_INDEX] \
--master_port [$WORKER_0_PORT] \
--nnodes [$WORKER_NUM] \
train.py \
--model_name_or_path "codellama/CodeLlama-7b-hf" \
--data_path "TIGER-Lab/MathInstruct" \
--bf16 True \
--output_dir checkpoints/MAmmoTH-Coder-7B \
--num_train_epochs 3 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True
```
To train the 34B/70B model, run:
```bash
torchrun --nproc_per_node [$WORKER_GPU] \
--master_addr [$WORKER_0_HOST] \
--node_rank [$ROLE_INDEX] \
--master_port [$WORKER_0_PORT] \
--nnodes [$WORKER_NUM] \
train.py \
--model_name_or_path "codellama/CodeLlama-34b-hf" \
--data_path "TIGER-Lab/MathInstruct" \
--bf16 True \
--output_dir checkpoints/MAmmoTH-Coder-34B \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "epoch" \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--deepspeed "ds_config/ds_config_zero3.json" \
--tf32 True
```
## Prompt Format
If you want to do CoT:
```
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
```
If you want to do PoT:
```
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction} Let's write a program.
### Response:
```
## WebUI
We use [llama2-webui](https://github.com/liltom-eth/llama2-webui) as our UI backend. To launch the web UI for MAmmoTH, run:
```
pip install gradio
cd webui/llama2-webui
python3 mammoth.py --model_path your_model_path --backend_type transformers
```
## **License**
Please check out the license of each subset in our curated dataset MathInstruct.
| Dataset Name | License Type |
|-------------- |---------------- |
| GSM8K | MIT |
| GSM8K-RFT | Not listed |
| AQuA-RAT | Apache 2.0 |
| MATH | MIT |
| TheoremQA | MIT |
| Camel-Math | Attribution-NonCommercial 4.0 International |
| NumGLUE | Apache-2.0 |
| CrowdSourced (Lila) | Attribution 4.0 International |
| MathQA | Apache-2.0 |
| Our Curated | MIT |
## **Citation**
Please cite our paper if you use our data, model or code. Please also kindly cite the original dataset papers.
```
@article{yue2023mammoth,
  title={MAmmoTH: Building Math Generalist Models through Hybrid Instruction Tuning},
  author={Xiang Yue and Xingwei Qu and Ge Zhang and Yao Fu and Wenhao Huang and Huan Sun and Yu Su and Wenhu Chen},
  journal={arXiv preprint arXiv:2309.05653},
  year={2023}
}
```

@@ -0,0 +1,40 @@
{
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": 1e-8,
            "weight_decay": "auto"
        }
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    }
}

Binary file added (image, 334 KiB); contents not shown.

Binary file added (image, 540 KiB); contents not shown.

@@ -0,0 +1,39 @@
import json
import sys

from utils import number_it, compare_two_numbers

assert len(sys.argv) >= 2, 'you need to feed in a file'


def compare(answer, groundtruth):
    # Exact string match first; otherwise fall back to a numeric comparison.
    groundtruth_str, groundtruth_num = groundtruth
    if answer == groundtruth_str:
        return True
    else:
        if groundtruth_num is not None and number_it(answer) is not None:
            if compare_two_numbers(number_it(answer), groundtruth_num):
                return True
            else:
                return False
        else:
            return False


for filename in sys.argv[1:]:
    correct, wrong = 0, 0
    with open(filename) as f:
        for line in f:
            entry = json.loads(line)
            groundtruth = entry['correct'] if 'correct' in entry else entry['Answer']
            if isinstance(groundtruth, list):
                if compare(entry['pred'], groundtruth):
                    correct += 1
                else:
                    wrong += 1
            else:
                if entry['pred'] == groundtruth:
                    correct += 1
                else:
                    wrong += 1
    print(filename, 'length=', correct + wrong, 'accuracy=', correct / (correct + wrong + 0.0001))

@@ -0,0 +1,217 @@
import json
from utils import delete_extra_zero,_strip_string
from statistics import mean
import re
import glob
import random
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{query}\n\n### Response:"
),
}
def find_math_answer(s):
assert('boxed' in s)
# s = s.replace(",", "")
ans = s.split('boxed')[-1]
if(ans[0] == '{'):
stack = 1
a = ''
for c in ans[1:]:
if(c == '{'):
stack += 1
a += c
elif(c == '}'):
stack -= 1
if(stack == 0): break
a += c
else:
a += c
else:
a = ans.split('$')[0].strip()
a=_strip_string(a)
return a
def extract_math_answer(pred_str):
if('The answer is ' in pred_str):
pred = pred_str.split('The answer is ')[-1].strip()
elif('the answer is ' in pred_str):
pred = pred_str.split('the answer is ')[-1].strip()
elif 'boxed' in pred_str:
ans = pred_str.split('boxed')[-1]
if (ans[0] == '{'):
stack = 1
a = ''
for c in ans[1:]:
if (c == '{'):
stack += 1
a += c
elif (c == '}'):
stack -= 1
if (stack == 0): break
a += c
else:
a += c
else:
a = ans.split('$')[0].strip()
a = _strip_string(a)
pred=a
else:
pattern = '-?\d*\.?\d+'
pred = re.findall(pattern, pred_str)
if(len(pred) >= 1):
pred = pred[-1]
else:
pred = ''
if pred != "":
if pred[-1] == ".":
pred = pred[:-1]
if pred[-1] == "/":
pred = pred[:-1]
pred=_strip_string(pred)
if 'boxed' in pred:
ans = pred.split('boxed')[-1]
if (ans[0] == '{'):
stack = 1
a = ''
for c in ans[1:]:
if (c == '{'):
stack += 1
a += c
elif (c == '}'):
stack -= 1
if (stack == 0): break
a += c
else:
a += c
else:
a = ans.split('$')[0].strip()
a = _strip_string(a)
pred=a
return pred
def data_reader(dataset: str):
questions = []
answers = []
decoder = json.JSONDecoder()
if dataset == "aqua":
with open('dataset/AQuA/AQuA.json') as f:
lines = f.readlines()
for line in lines:
json_res = decoder.raw_decode(line)[0]
choice = "(" + "(".join(json_res["options"])
choice = choice.replace("(", " (").replace(")", ") ")
choice = "Answer Choices:" + choice
questions.append(json_res["question"].strip() + "\n" + choice)
answers.append(json_res["correct"])
elif dataset == 'math':
with open('dataset/math/MATH.json', 'r') as f:
loaded = json.load(f)
# random.shuffle(loaded)
# random.shuffle(loaded)
for d in loaded:
questions.append(d['question'])
answers.append(d['answer'])
elif dataset == "gsm8k":
with open('dataset/gsm8k/gsm8k.jsonl') as f:
lines = f.readlines()
for line in lines:
json_res = decoder.raw_decode(line)[0]
questions.append(json_res["question"].strip())
answers.append(delete_extra_zero(json_res["answer"].split("#### ")[-1].replace(",", "")))
elif dataset == "svamp":
with open('dataset/SVAMP/SVAMP.json') as f:
json_data = json.load(f)
for line in json_data:
q = line["Body"].strip() + " " + line["Question"].strip()
a = str(line["Answer"])
if a[-2:] == ".0":
a = a[:-2]
questions.append(q)
answers.append(delete_extra_zero(a))
elif 'mmlu' in dataset:
with open(f'dataset/mmlu/{dataset.split("_")[1]}.json') as f:
json_data = json.load(f)
for line in json_data:
options = f'(A) {line["choices"][0]} (B) {line["choices"][1]} (C) {line["choices"][2]} (D) {line["choices"][3]}'
q = line["question"] + '\n' + 'Answer Choices: ' + options
a = ['A', 'B', 'C', 'D'][line['answer']]
questions.append(q)
answers.append(a)
elif dataset in ['numglue', 'simuleq', 'deepmind', 'sat']:
with open(f'dataset/{dataset}/{dataset}.json') as f:
json_data = json.load(f)
for line in json_data:
assert isinstance(line['question'], str) and isinstance(line['question'], str), line
questions.append(line['question'])
answers.append(str(line['answer']))
else:
raise ValueError("dataset is not properly defined ...")
q_len_list = []
for q in questions:
q_len_list.append(len(q.split(" ")))
q_len_mean = mean(q_len_list)
print("dataset : {}".format(dataset))
print("data size : {}".format(len(answers)))
print("average num of words for each sample : {}".format(q_len_mean))
return questions, answers
class BatchDatasetLoader:
    def __init__(self, dataset: str, batch_size: int):
        self.inputs, self.outputs = data_reader(dataset)
        self.index = 0
        self.batch_size = batch_size
        self.length = len(self.inputs)
        print(self.length, self.batch_size)

    def __len__(self):
        if self.batch_size == -1:
            return 1
        else:
            return self.length // self.batch_size

    def __getitem__(self, index):
        if self.batch_size == -1:
            # batch_size == -1 means "return the whole dataset as a single batch".
            if index >= self.__len__():
                raise StopIteration
            else:
                return self.inputs, self.outputs
        else:
            if self.length % self.batch_size == 0:
                if index >= self.__len__():
                    raise StopIteration
                else:
                    tmp_inputs, tmp_outputs = [], []
                    for i in range(index * self.batch_size, min((index + 1) * self.batch_size, self.length)):
                        tmp_inputs.append(self.inputs[i])
                        tmp_outputs.append(self.outputs[i])
                    return tmp_inputs, tmp_outputs
            else:
                # When the dataset size is not divisible by batch_size, allow one
                # extra (smaller) final batch, hence the strict '>' comparison.
                if index > self.__len__():
                    raise StopIteration
                else:
                    tmp_inputs, tmp_outputs = [], []
                    for i in range(index * self.batch_size, min((index + 1) * self.batch_size, self.length)):
                        tmp_inputs.append(self.inputs[i])
                        tmp_outputs.append(self.outputs[i])
                    return tmp_inputs, tmp_outputs

@@ -0,0 +1 @@
All the output files will be saved here.

@@ -0,0 +1,501 @@
def get_prompt(qas: list, form: str):
if form == 'alpaca':
prompt_no_input, prefix = get_alpaca_format_prompt_wo_input(qas)
elif form == 'alpaca_mc':
prompt_no_input, prefix = get_alpaca_format_mc_prompt_wo_input(qas)
elif form == 'vicuna':
prompt_no_input, prefix = get_vicuna_format_prompt(qas)
elif form == 'short':
prompt_no_input, prefix = get_short_format_prompt(qas)
elif form == 'step':
prompt_no_input, prefix = get_step_by_step(qas)
elif form == 'tulu':
prompt_no_input, prefix = get_tulu_format_prompt(qas)
elif form == 'guanaco':
prompt_no_input, prefix = get_Guanaco_format_prompt(qas)
elif form == 'llama2chat':
prompt_no_input, prefix = get_Guanaco_format_prompt(qas)
else:
raise NotImplementedError(form)
return prompt_no_input, prefix
def get_tulu_format_prompt(qas: list):
# tmp = (
# "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
# )
tmp = ""
for q, a in qas:
tmp += '<|user|>\n{query}\n <|assistant|>\nThe answer is: {response}\n'.format(query=q, response=a)
prefix = '<|user|>\n{query}\n<|assistant|>\nThe answer is: '
return tmp, prefix
def get_vicuna_format_prompt(qas: list):
tmp = (
"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
)
for q, a in qas:
tmp += '\n\n' + 'USER: {query} \n ASSISTANT: {response}\n'.format(query=q, response=a)
prefix = '\n' + 'USER: {query}\n ASSISTANT: '
return tmp, prefix
def get_Guanaco_format_prompt(qas: list):
tmp = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
)
for q, a in qas:
tmp += '\n\n' + '### Human: {query}\n### Assistant: {response}\n'.format(query=q, response=a)
prefix = '\n' + '### Human: {query}\n### Assistant:'
return tmp, prefix
def get_llama2_chat_format_prompt(qas: list):
tmp = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions. "
)
for q, a in qas:
tmp += '\n\n' + '### Human: {query}\n### Assistant: {response}\n'.format(query=q, response=a)
prefix = '\n' + '### Human: {query}\n### Assistant:'
return tmp, prefix
def get_alpaca_format_prompt_wo_input(qas: list):
tmp = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n"
)
for q, a in qas:
tmp += '\n' + '### Instruction:\n{query}\n\n### Response: {response}\n'.format(query=q, response=a)
prefix = '\n' + '### Instruction:\n{query}\n\n### Response:'
return tmp, prefix
def get_alpaca_format_mc_prompt_wo_input(qas: list):
tmp = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n"
)
for q, a in qas:
tmp += '\n' + '### Instruction:\n{query}\n\n### Response: Let\'s solve the multi-choice question step by step.\n{response}\n'.format(query=q, response=a)
prefix = '\n' + '### Instruction:\n{query}\n\n### Response: Let\'s solve the multi-choice question step by step.\n'
return tmp, prefix
def get_step_by_step(qas: list):
tmp = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n"
)
for q, a in qas:
tmp += '\n' + '### Instruction:\n{query}\n\n### Response: Let\'s think step by step. {response}\n'.format(query=q, response=a)
prefix = '\n' + '### Instruction:\n{query}\n\n### Response: Let\'s think step by step.'
return tmp, prefix
def get_short_format_prompt(qas: list):
tmp = "You are supposed to answer a math question by showing the steps to derive the answer.\n\n"
for q, a in qas:
tmp += '\n' + 'Q: {query}\nA:{response}\n'.format(query=q, response=a)
# prefix = '\n' + 'Q:\n{query}\nA:'
prefix = '\n' + 'Q: {query}\nA:'
return tmp, prefix
def split_examples(examples: str):
qas = []
for ex in examples.split('\n\n'):
q, a = ex.split('\n')
qas.append((q, a))
return qas
def get_examples(name: str, num_shots: int, pot_flag: str):
if num_shots == 0:
return []
examples = {}
examples['aqua'] = [
(
"John found that the average of 15 numbers is 40. If 10 is added to each number then the mean of the numbers is?\nAnswer Choices: (A) 50 (B) 45 (C) 65 (D) 78 (E) 64",
"If 10 is added to each number, then the mean of the numbers also increases by 10. So the new mean would be 50. The answer is (A)."
),
(
"If a / b = 3/4 and 8a + 5b = 22,then find the value of a.\nAnswer Choices: (A) 1/2 (B) 3/2 (C) 5/2 (D) 4/2 (E) 7/2",
"a / b = 3/4, then b = 4a / 3. So 8a + 5(4a / 3) = 22. This simplifies to 8a + 20a / 3 = 22, which means 44a / 3 = 22. So a is equal to 3/2. The answer is (B)."
),
(
"A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance?\nAnswer Choices: (A) 53 km (B) 55 km (C) 52 km (D) 60 km (E) 50 km",
"The distance that the person traveled would have been 20 km/hr * 2.5 hrs = 50 km. The answer is (E)."
),
(
"How many keystrokes are needed to type the numbers from 1 to 500?\nAnswer Choices: (A) 1156 (B) 1392 (C) 1480 (D) 1562 (E) 1788",
"There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401 three-digit numbers from 100 to 500. 9 + 90(2) + 401(3) = 1392. The answer is (B)."
),
]
examples['sat'] = [
(
"If $\frac{x-1}{3}=k$ and $k=3$, what is the value of $x$ ? \nAnswer Choices: (A) 2 (B) 4 (C) 9 (D) 10",
"If k = 3, then x - 1 = 3 * 3, therfore, x - 1 = 9 and x = 10. The answer is D",
),
(
"For $i=\sqrt{-1}$, what is the sum $(7+3 i)+(-8+9 i)$ ? \nAnswer Choices: (A) $-1+12 i$ (B) $-1-6 i$ (C) $15+12 i$ (D) $15-6 i$ 3",
"For (7+3 i)+(-8+9 i), the real part is 7 + (-8) = -1, the imageinary part is 3 i + 9 i = 12 i. The answer is A",
),
(
"On Saturday afternoon, Armand sent $m$ text messages each hour for 5 hours, and Tyrone sent $p$ text messages each hour for 4 hours. Which of the following represents the total number of messages sent by Armand and Tyrone on Saturday afternoon?\nAnswer Choices: (A) $9 m p$ (B) $20 m p$ (C) $5 m+4 p$ (D) $4 m+5 p$",
"Armand texts m messages each hour for 5 hours, which leads to 5m messages. Tyrone texts p messages each hour for 4 hours, which leds to 4p messages. The total is 5m + 4p. The answer is C.",
),
(
"$$\begin{array}{r}3 x+4 y=-23 \\2 y-x=-19\end{array}$$What is the solution $(x, y)$ to the system of equations above?\nAnswer Choices: (A) $(-5,-2)$ (B) $(3,-8)$ (C) $(4,-6)$ (D) $(9,-6)$",
"By solving this equation, we found that x = 3 and y = -8. The answer is B.",
)
]
examples["mmlu_mathematics"] = [
(
"Simplify and write the result with a rational denominator: $$\sqrt{\sqrt[3]{\sqrt{\frac{1}{729}}}}$$\nAnswer Choices: (A) \frac{3\sqrt{3}}{3} (B) \frac{1}{3} (C) \sqrt{3} (D) \frac{\sqrt{3}}{3}",
"Factoring $729=3^6$ and combining the roots $\frac{1}{2}\frac{1}{3}\frac{1}{2}=\frac{1}{12}$, we get that $\sqrt{\sqrt[3]{\sqrt{\frac{1}{729}}}}=\left(\frac{1}{3^6}\right)^{\frac{1}{12}}=\frac{1}{3^{\frac{1}{2}}}=\frac{3}{\sqrt{3}}$. The answer is (D)."
),
(
"Five thousand dollars compounded annually at an $x\%$ interest rate takes six years to double. At the same interest rate, how many years will it take $\$300$ to grow to $\$9600$?\nAnswer Choices:(A) 12 (B) 1 (C) 30 (D) 5",
"To go from $\$300$ to $\$9600$, the value must go up by a factor of $9600/300=32=2^5$. Since at this interest rate it takes six years for it to double, it will take $5*6=30$ years to grow to $\$9600$. The answer is (C)."
),
(
"Ten students take a biology test and receive the following scores: 45, 55, 50, 70, 65, 80, 40, 90, 70, 85. What is the mean of the students test scores?\nAnswer Choices: (A) 55 (B) 60 (C) 62 (D) 65",
"There are 10 students and the sum of their scores is $45 + 55 + 50 + 70 + 65 + 80 + 40 + 90 + 70 + 85 = 650$, the mean is $650/10=65$. The answer is (D)."
),
(
"The variable $x$ varies directly as the square of $y$, and $y$ varies directly as the cube of $z$. If $x$ equals $-16$ when $z$ equals 2, what is the value of $x$ when $z$ equals $\frac{1}{2}$?\nAnswer Choices: (A) -1 (B) 16 (C) -\frac{1}{256} (D) \frac{1}{16}",
"We know that $x \propto y^2$ and $y \propto z^3$, so $x = k z^6$ for some constant $k$. Plugging in for $x=-16$ and $z=2$, the constant value is $k=\frac{x}{z^6}=\frac{-16}{64}=-\frac{1}{4}$. So, when $z=\frac{1}{2}$, the value of $x$ is $x=kz^6=-\frac{1}{4}\frac{1}{2^6}=-\frac{1}{256}$. The answer is (C).",
),
(
"Joe was in charge of lights for a dance. The red light blinks every two seconds, the yellow light every three seconds, and the blue light every five seconds. If we include the very beginning and very end of the dance, how many times during a seven minute dance will all the lights come on at the same time? (Assume that all three lights blink simultaneously at the very beginning of the dance.)\nAnswer Choices: (A) 3 (B) 15 (C) 6 (D) 5",
"The least common multiple of 2, 3 and 5 is 30, so during a 7 minute dance, all the three lights will come on at the same time $2*7+1=15$ times. The answer is (B)."
)
]
examples["mmlu_physics"] = [
(
"A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\nAnswer Choices: (A) 10 W (B) 30 W (C) 60 W (D) 240 W",
"Rate of energy usage is known as power; in an dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D).",
),
(
"A point charge, Q = +1 mC, is fixed at the origin. How much work is required to move a charge, Q = +8 µC, from the point (0, 4 meters) to the point (3 meters, 0)?\nAnswer Choices: (A) 3.5 J (B) 6.0 J (C) 22.5 J (D) 40 J",
"To calculate the work required to move a charge from one location to another in a fixed electric field, it is enough to calculate the potential difference between the two locations. Here, the potential only depends on the distance between the charges; its $k q_1 q_2 / r$, where $k$ is Coulombs constant. Plugging in values $q_1 = $ 1 mC, $q_2 = 8 \mu$ C, gives the answer as 5.992 J, which rounds to 6 J. The answer is (B).",
),
(
"Which of the following conditions will ensure that angular momentum is conserved? I. Conservation of linear momentum II. Zero net external force III. Zero net external torque.\nAnswer Choices: (A) I and II only (B) I and III only (C) II and III only (D) III only",
"Torque is defined as the change in angular momentum; if there is zero external torque, angular momentum is conserved. The answer is (D)."
),
(
"A photocell of work function ϕ = 2eV is connected to a resistor in series. Light of frequency f = 1 × 10^15 Hz hits a metal plate of the photocell. If the power of the light is P = 100 W, what is the current through the resistor?\nAnswer Choices: (A) 2:00 AM (B) 6:00 AM (C) 12:00 AM (D) 24 A",
"The only answer above which has units of current is D, 24 A. The answer is (D).",
),
(
"A pipe full of air is closed at one end. A standing wave is produced in the pipe, causing the pipe to sound a note. Which of the following is a correct statement about the waves properties at the closed end of the pipe?\nAnswer Choices: (A) The pressure is at a node, but the particle displacement is at an antinode. (B) The pressure is at an antinode, but the particle displacement is at a node. (C) The pressure and the particle displacement are both at nodes. (D) The pressure and the particle displacement are both at antinodes.",
"At the closed end of the pipe, the particles cannot have any net displacement because the pipe closure stops them. So the particle displacement is at a node. This closure also causes the pressure to be maximal, i.e. an antinode. The answer is (B)."
)
]
examples["mmlu_chemistry"] = [
(
"Which of the following is considered an acid anhydride?\nAnswer Choices: (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3",
"An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).",
),
(
"Which of the following is expected to be a polar molecule?\nAnswer Choices: (A) PCl4F (B) BF3 (C) CO2 (D) Si(CH3)4",
"A polar molecule is one that has a slightly positive charge on one end of the molecule and a slightly negative charge on the other end. Boron trifluoride (BF3) has Boron as the center atom and three fluorine atoms attached to it; it is trigonal planar and symmetric, so it is nonpolar. Carbon Dioxide (CO2) has Carbon as the central atom with double bonds to two Oxygen atoms - this is also symmetrical and therefore nonpolar. The same is the case for tetramethyl silane (SI(CH3)4), which is a Silicon atom surrounded by four methyl groups. The structure of PCL4F is that Phosphorus is the central atom, attached to four chlorines and one fluorine atom. This is asymmetrical, and therefore has a net dipole and is expected to be a polar molecule. The answer is (A)."
),
(
"From the solubility rules, which of the following is true?\nAnswer Choices: (A) All chlorides, bromides, and iodides are soluble (B) All sulfates are soluble (C) All hydroxides are soluble (D) All ammonium-containing compounds are soluble",
"The chlorides, bromides, and iodides of lead, silver, and mercury are not soluble in water. This rules out (A). The sulfates of lead, barium, and calcium are not soluble in water, which rules out (B). The hydroxides of any metal besides sodium, potassium, ammonium, calcium, and barium are insoluble. This rules out (C). Typically ammonium ions indicate a soluble ionic substance. The answer is (D).",
),
(
"A new compound is synthesized and found to be a monoprotic acid with a molar mass of 248 g/mol. When 0.0050 mol of this acid are dissolved in 0.500 L of water, the pH is measured as 3.89. What is the pKa of this acid?\nAnswer Choices: (A) 3.89 (B) 7.78 (C) 5.78 (D) 2.33",
"Recall that $[A] = [H^{+}]$. Here, this is equal to $$10^{-3.89}$. Then we have $K_{a} = $frac{[H^{+}][A^{-}]}{[HA]} = \frac{10^{-3.89} \cdot 10^{-3.89}}{10^{-2}}. The resulting exponent is $-3.89 + (-3.89) - (-2) = 5.78$, therefore $K_a = 10^{-5.78}$. The $pK_a$ is the negative log of $K_a$, which is equal to $5.78$. The answer is (C)."
),
(
"A solution contains 2.00 mole of acetic acid, CH3COOH, and 1.00 mole of calcium acetate, Ca(CH3COO)2. The solution is able to resist the addition of a small amount of strong acid or strong base with only minor changes in the pH of the solution. Larger quantities of strong acid or strong base can cause a significant change in pH. How many moles of nitric acid, HNO3, may be added before the pH begins to change significantly?\nAnswer Choices: (A) 0.500 mole (B) 1.00 mole (C) 2.00 mole (D) 3.00 mole",
"We would like to compute the buffer capacity of this solution. First we write the equation for the ionization of the weak acid, in this case of acetic acid. $CH_{3}COOH (aq) + H_{2}O \rightarrow H_{3}O^{+} + CH3COO^{-}$. The conjugate base is therefore the acetate ion. The added strong acid, Nitric acid, will react with the conjugate base. Therefore the maximum amount of acid that can be added will be equal to the amount of acetate ion, or 2 moles. The answer is (C)."
)
]
examples["mmlu_biology"] = [
(
"In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\nAnswer Choices: (A) Plasma membraneGolgi apparatusribosomesecretory vesiclerough ER (B) RibosomeGolgi apparatusrough ERsecretory vesicleplasma membrane (C) Plasma membraneGolgi apparatusribosomesecretory vesiclerough ER (D) Ribosomerough ERGolgi apparatussecretory vesicleplasma membrane",
"Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).",
),
(
"A mutation in a bacterial enzyme changed a previously polar amino acid into a nonpolar amino acid. This amino acid was located at a site distant from the enzymes active site. How might this mutation alter the enzymes substrate specificity?\nAnswer Choices: (A) By changing the enzymes pH optimum (B) By changing the enzymes location in the cell (C) By changing the shape of the protein (D) An amino acid change away from the active site cannot alter the enzymes substrate specificity.",
"A change in an amino acid leads to a change in the primary structure of the protein. A change in the primary structure may lead to a change in the secondary and the tertiary structure of the protein. A change in the tertiary structure means a change in the shape of the protein, so (C) has to be correct. Since the change does not affect the active site of the enzyme, we do not expect the activity of the enzyme to be affected. The answer is (C)."
),
(
"Which of the following is not a way to form recombinant DNA?\nAnswer Choices: (A) Translation (B) Conjugation (C) Specialized transduction (D) Transformation",
"The introduction of foreign DNA or RNA into bacteria or eukaryotic cells is a common technique in molecular biology and scientific research. There are multiple ways foreign DNA can be introduced into cells including transformation, transduction, conjugation, and transfection. In contrast, (A) is not a way to form DNA: during translation the ribosomes synthesize proteins from RNA. The answer is (A)."
),
(
"Homologous structures are often cited as evidence for the process of natural selection. All of the following are examples of homologous structures EXCEPT\nAnswer Choices: (A) the wings of a bird and the wings of a bat (B) the flippers of a whale and the arms of a man (C) the pectoral fins of a porpoise and the flippers of a seal (D) the forelegs of an insect and the forelimbs of a dog",
"Homologous structures are similar physical features in organisms that share a common ancestor but different functions. Comparisons (B) and (C) are clearly homologous because they share a common ancestor and the structures serve different purposes. Bat wings and birg wings are also homologous, while they are both wings, the forelimbs serve different purposes. Insects and dogs are very far ancestors since one is vertebrate while the other is invertebrate and the forelimbs serve the same purpose, so they are not homologous. The answer is (D)."
),
(
"Which of the following is not known to be involved in the control of cell division?\nAnswer Choices: (A) Cyclins (B) Protein kinases (C) Checkpoints (D) Fibroblast cells",
"Normal cells move through the cell cycle in a regulated way. At the checkpoint stage, they use information about their own internal state and cues from the environment around them to decide whether to proceed with cell division. Cues like these act by changing the activity of core cell cycle regulators inside the cell. The most common regulators are cyclins and cyclin-dependent kinases. Fibroblast cells do not play any role in cell division. The answer is (D)."
)
]
examples['gsm8k']= [
(
'There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?',
'There are 15 trees originally. Then there were 21 trees after the Grove workers planted some more. So there must have been 21 - 15 = <<21-15=6>>6 trees that were planted.\n#### 6'
),
(
'If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?',
'There are originally 3 cars. Then 2 more cars arrive. Now 3 + 2 = <<3+2=5>>5 cars are in the parking lot.\n#### 5'
),
(
'Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?',
'Originally, Leah had 32 chocolates and her sister had 42. So in total they had 32 + 42 = <<32+42=74>>74. After eating 35, they had 74 - 35 = <<74-35=39>>39 pieces left in total.\n#### 39'
),
(
'Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?',
'Jason had 20 lollipops originally. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = <<20-12=8>>8 lollipops.\n#### 8'
),
(
'Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?',
'Shawn started with 5 toys. He then got 2 toys each from his mom and dad. So he got 2 * 2 = <<2*2=4>>4 more toys. Now he has 5 + 4 = <<5+4=9>>9 toys.\n#### 9'
),
(
'There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?',
'There were originally 9 computers. For each day from monday to thursday, 5 more computers were installed. So 4 * 5 = <<4*5=20>>20 computers were added. Now 9 + 20 = <<9+20=29>>29 computers are now in the server room.\n#### 29'
),
(
'Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?',
'Michael started with 58 golf balls. He lost 23 on Tuesday, and lost 2 more on wednesday. So he had 58 - 23 = <<58-23=35>>35 at the end of Tuesday, and 35 - 2 = <<35-2=33>>33 at the end of wednesday.\n#### 33'
),
(
'Olivia has $23. She bought five bagels for $3 each. How much money does she have left?',
'Olivia had 23 dollars. She bought 5 bagels for 3 dollars each. So she spent 5 * 3 = <<5*3=15>>15 dollars. Now she has 23 - 15 = <<23-15=8>>8 dollars left.\n#### 8'
)
]
examples['gsm8k_pot']= [
(
'There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?',
'original_trees = 15\nfinal_trees = 21\nplanted_trees = final_trees - original_trees\nprint(planted_trees)\n'
),
(
'If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?',
'original_cars = 3\narriving_cars = 2\ntotal_cars = original_cars + arriving_cars\nprint(total_cars)\n'
),
(
'Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?',
'leah_chocolates = 32\nsister_chocolates = 42\neaten_chocolates = 35\nremaining_chocolates = leah_chocolates + sister_chocolates - eaten_chocolates\nprint(remaining_chocolates) \n'
),
(
'Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?',
'jason_original_lollipops = 20\njason_remaining_lollipops = 12\ngiven_lollipops = jason_original_lollipops - jason_remaining_lollipops\nprint(given_lollipops)\n'
),
(
'Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?',
'original_toys = 5 \ntoys_from_mom = 2 \ntoys_from_dad = 2 \ntotal_toys = original_toys + toys_from_mom + toys_from_dad \nprint(total_toys)\n'
),
(
'There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?',
'original_computers = 9\ncomputers_installed_per_day = 5\ndays = 4\ntotal_computers = original_computers + computers_installed_per_day * days\nprint(total_computers)\n'
),
(
'Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?',
'original_golf_balls = 58\nlost_on_tuesday = 23\nlost_on_wednesday = 2\nremaining_golf_balls = original_golf_balls - lost_on_tuesday - lost_on_wednesday\nprint(remaining_golf_balls)\n'
),
(
'Olivia has $23. She bought five bagels for $3 each. How much money does she have left?',
'olivia_money = 23\nbagel_cost = 3\nbagels_bought = 5\nremaining_money = olivia_money - bagel_cost * bagels_bought\nprint(remaining_money)\n'
)
]
examples['svamp'] = [
(
'children were riding on the bus. At the bus stop 82 children got on the bus while some got off the bus. Then there were 30 children altogether on the bus. How many more children got on the bus than those that got off?',
'Let\'s assume there are x students getting on the bus and y students getting off the bus, than 28 + x - y = 30. Therefore, x - y = 30 - 28 = 2. The answer is 2.'),
(
'Mary is baking a cake. The recipe calls for 11 cups of flour and 7 cups of sugar. She already put in some cups of flour. If she still needs 2 more cups of flour than sugar, How many cups of flour did she put in?',
'Let\'s assume there are x cups of clour already added, so we know 11 - x is the remaining cups of flour. Since the reminaing cups of flour is 2 more than sugar. Then we have 11 - x - 2 = 7. Therefore, x = 11 - 2 - 7 = 2. The answer is 2.',
),
(
'Frank put 11 pieces of candy in each bag. If he had 22 pieces of candy. How many bags would he have?',
'Let\'s assume there are x bags. Then we know 11 * x = 22. Therefore, x = 22 / 11 = 2. The answer is 2.'
),
(
'A farmer had 90 tomatoes in his garden. If he picked 154 of them yesterday and 50 today. How many tomatoes did he pick in all?',
'The number of tomatoes picked is x = 154 + 50. Therefore, x = 204. The answer is 204.'
),
(
'The grasshopper, the frog and the mouse had a jumping contest. The grasshopper jumped 19 inches. The frog jumped 10 inches farther than the grasshopper and the mouse jumped 20 inches farther than the frog. How much farther did the mouse jump than the grasshopper?',
'frog jumps 19 + 10 = 29 inches. The mouse jumps 29 + 20 = 49 inches. Thefore, the mouse jumps 49 - 19 = 30 inches farther than grasshopper. The answer is 30.',
),
(
'Allan brought 3 balloons and Jake brought 5 balloons to the park. Allan then bought 2 more balloons at the park. How many balloons did Allan and Jake have in the park?',
'Allan has 3 + 2 = 5 ballons. Jake as 5 ballons. Therefore, the total ballons is 5 + 5 = 10. The answer is 10.',
),
(
'Jake has 7 fewer peaches than Steven and 9 more peaches than Jill. Steven has 16 peaches. How many peaches does Jake have?',
'Let\'s assume Jake has x peaches. x + 7 = 16. Therefore, x = 16 - 7 = 9. The answer is 9.'
),
(
'Katie had 57 new games and 39 old games. Her friends had 34 new games. How many more games does Katie have than her friends?',
'Katie has a total of 57 + 39 = 96 games. Therefore, Katie has 96 - 34 = 62 games than her friend. The answer is 62.'
)
]
examples['svamp_pot'] = [
(
'children were riding on the bus. At the bus stop 82 children got on the bus while some got off the bus. Then there were 30 children altogether on the bus. How many more children got on the bus than those that got off?',
'children_after = 30\nchildren_joined = 82\nchildren_initial = children_after - children_joined\nchildren_left = children_initial - children_after\nmore_children_joined = children_joined - children_left\nprint(more_children_joined)\n'
),
(
'Mary is baking a cake. The recipe calls for 11 cups of flour and 7 cups of sugar. She already put in some cups of flour. If she still needs 2 more cups of flour than sugar, How many cups of flour did she put in?',
'needed_flour = 11\nneeded_sugar = 7\nremaining_flour_diff = 2\nalready_put_flour = needed_flour - (needed_sugar + remaining_flour_diff)\nprint(already_put_flour)\n'
),
(
'Frank put 11 pieces of candy in each bag. If he had 22 pieces of candy. How many bags would he have?',
'frank_candies = 22\ncandies_per_bag = 11\nbags_needed = frank_candies / candies_per_bag\nprint(bags_needed)\n'
),
(
'A farmer had 90 tomatoes in his garden. If he picked 154 of them yesterday and 50 today. How many tomatoes did he pick in all?',
'picked_yesterday = 154\npicked_today = 50\ntotal_picked = picked_yesterday + picked_today\nprint(total_picked)\n'
),
(
'The grasshopper, the frog and the mouse had a jumping contest. The grasshopper jumped 19 inches. The frog jumped 10 inches farther than the grasshopper and the mouse jumped 20 inches farther than the frog. How much farther did the mouse jump than the grasshopper?',
'grasshopper_jump = 19\nfrog_jump = grasshopper_jump + 10\nmouse_jump = frog_jump + 20\nmouse_jump_diff = mouse_jump - grasshopper_jump\nprint(mouse_jump_diff)\n'
),
(
'Allan brought 3 balloons and Jake brought 5 balloons to the park. Allan then bought 2 more balloons at the park. How many balloons did Allan and Jake have in the park?',
'allan_initial_balloons = 3\njake_balloons = 5\nallan_bought = 2\ntotal_balloons = allan_initial_balloons + jake_balloons + allan_bought\nprint(total_balloons)\n'
),
(
'Jake has 7 fewer peaches than Steven and 9 more peaches than Jill. Steven has 16 peaches. How many peaches does Jake have?',
'steven_peaches = 16\njake_peaches_diff = -7\njake_peaches = steven_peaches + jake_peaches_diff\nprint(jake_peaches)\n'
),
(
'Katie had 57 new games and 39 old games. Her friends had 34 new games. How many more games does Katie have than her friends?',
'katie_new_games = 57\nkatie_old_games = 39\nfriends_new_games = 34\ntotal_katie_games = katie_new_games + katie_old_games\ndifference_games = total_katie_games - friends_new_games\nprint(difference_games)\n'
)
]
examples['math'] = [
(
"The sum of two numbers is 6. The difference of their squares is 12. What is the positive difference of the two numbers?",
"""Let's think step by step
Call the two numbers $x$ and $y$.
We are given that $x+y = 6$ and $x^2 - y^2 = 12$.
Because $x^2 - y^2$ factors into $(x+y)(x-y)$,
we can substitute in for $x+y$,
giving $6(x-y) = 12$,
or $x-y = \boxed{2}$.
The answer is 2"""
),
(
"Which integer is closest to the cube root of 100?",
"""Let's think step by step
Either 4 or 5 is closest to $\sqrt[3]{100}$,
since $4^3=64$ and $5^3=125$.
Since $4.5^3=91.125<100$,
$\sqrt[3]{100}$ is closer to $\boxed{5}$ than to 4.
The answer is 5"""
),
(
"What is the value of $(x - y)(x + y)$ if $x = 10$ and $y = 15$?",
"""Let's think step by step
$(x-y)(x+y)=(10-15)(10+15) = (-5)(25) = \boxed{-125}$.
The answer is -125"""
),
(
"If $g(x) = 3x + 7$ and $f(x) = 5x - 9$, what is the value of $f(g(8))$?",
"""Let's think step by step
$g(8)=3(8)+7=24+7=31$.
Thus, $f(g(8))=f(31)=5(31)-9=155-9=\boxed{146}$.
The answer is 146"""),
(
"What is the greatest possible positive integer value of $x$ if $\displaystyle\frac{x^4}{x^2} < 10$?",
"""Let's think step by step
On the left-hand side, $x^2$ cancels, reducing the inequality to $x^2<10$.
Since $3^2=9<10$ while $4^2=16>10$, the greatest possible value of $x$ is $\boxed{3}$.
The answer is 3"""
),
(
"A scale drawing of a park shows that one inch represents 800 feet. A line segment in the drawing that is 4.75 inches long represents how many feet?",
"""Let's think step by step
Each inch of the 4.75-inch line segment represents 800 feet,
so the whole line segment represents $4.75\times800=\frac{19}{4}\cdot800=19\cdot200=\boxed{3800}$ feet.
The answer is 3800"""
),
(
"In Mr. Abraham's class, $10$ of the $15$ students received an $A$ on the latest exam. If the same ratio of students received an $A$ on Mrs. Berkeley's latest exam, and if Mrs. Berkeley has $24$ students total, how many students in Mrs. Berkeley's class received an $A$?",
"""Let's think step by step
If $10$ of $15$ students received an $A$,
then the ratio of students receiving an $A$ to students not receiving an $A$ is $\frac{10}{15}$, or $\frac{2}{3}$.
Let $x$ be the number of students in Mrs. Berkeley's class who received an $A$.
Since the ratio is consistent across the two classes, $\frac{2}{3} = \frac{x}{24}$.
Cross-multiplying yields $x = \frac{24\cdot 2}{3}$, so, by simplification, we can see that $\boxed{16}$ of Mrs. Berkeley's students must have received an $A$.
The answer is 16"""
),
(
"Find the value of the first term in the geometric sequence $a,b,c,32,64$.",
"""Let's think step by step
The common ratio is $\frac{64}{32} = 2$.
Therefore, the first term is $\frac{32}{2^3} = \frac{32}{8} = \boxed{4}$.
The answer is 4""")]
examples['math_pot'] = [
(
"The sum of two numbers is 6. The difference of their squares is 12. What is the positive difference of the two numbers?",
'sum_of_nums = 6\nsquare_diff = 12\nhalf_sum = sum_of_nums / 2\nprod_of_nums = (half_sum**2 - square_diff / 4)\nx = (half_sum + (prod_of_nums**0.5)) / 2\ny = sum_of_nums - x\npositive_diff = abs(x - y)\nprint(positive_diff)\n'
),
(
"Which integer is closest to the cube root of 100?",
'closest_cube_root = round(100 ** (1/3))\nprint(closest_cube_root)\n'
),
(
"What is the value of $(x - y)(x + y)$ if $x = 10$ and $y = 15$?",
'x = 10\ny = 15\nvalue = (x - y) * (x + y)\nprint(value)\n'
),
(
"If $g(x) = 3x + 7$ and $f(x) = 5x - 9$, what is the value of $f(g(8))$?",
'x_val = 8\ng_of_x = 3 * x_val + 7\nf_of_g_of_x = 5 * g_of_x - 9\nprint(f_of_g_of_x)\n'
),
(
"What is the greatest possible positive integer value of $x$ if $\displaystyle\frac{x^4}{x^2} < 10$?",
'upper_limit = (10)**0.5 # because x^4 / x^2 < 10 simplifies to x^2 < 10\nx_max = int(upper_limit)\nprint(x_max)\n'
),
(
"A scale drawing of a park shows that one inch represents 800 feet. A line segment in the drawing that is 4.75 inches long represents how many feet?",
'scale_factor = 800\ndrawing_length = 4.75\nreal_length = scale_factor * drawing_length\nprint(real_length)\n'
),
(
"In Mr. Abraham's class, $10$ of the $15$ students received an $A$ on the latest exam. If the same ratio of students received an $A$ on Mrs. Berkeley's latest exam, and if Mrs. Berkeley has $24$ students total, how many students in Mrs. Berkeley's class received an $A$?",
'abraham_ratio = 10 / 15\nberkeley_students = 24\nberkeley_A_students = int(abraham_ratio * berkeley_students)\nprint(berkeley_A_students)\n'
),
(
"Find the value of the first term in the geometric sequence $a,b,c,32,64$.",
'fifth_term = 64\nr = fifth_term / 32\nfirst_term = fifth_term / (r**4)\nprint(first_term)\n'
)]
examples['numglue'] = examples['svamp'][:6] + examples['aqua'][:2]
examples['simuleq'] = examples['svamp']
examples['deepmind'] = examples['svamp']
if 'pot' in pot_flag:
name += '_pot'
examples['numglue_pot'] = examples['svamp_pot'][:6] + examples['aqua'][:2]
examples['simuleq_pot'] = examples['svamp_pot']
examples['deepmind_pot'] = examples['svamp_pot']
print(examples[name])
return examples[name][:num_shots]

@@ -0,0 +1,147 @@
# Load model directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from tqdm import tqdm
import argparse
import utils
from prompt_utils import *
from data_loader import BatchDatasetLoader
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='', type=str)
parser.add_argument("--output", default='', type=str)
parser.add_argument("--shots", default=0, type=int)
parser.add_argument("--dataset", required=True,
choices=['aqua', 'sat', 'mmlu_mathematics',
'mmlu_physics', 'mmlu_chemistry', 'mmlu_biology'],
type=str)
parser.add_argument("--dtype", default='bfloat16', type=str)
parser.add_argument("--load_8bit", action='store_true', default=False)
parser.add_argument("--stem_flan_type", default='', choices=['', 'pot_prompt'], type=str)
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--print", action='store_true', default=False)
parser.add_argument("--form", default='alpaca_mc', type=str)
parser.add_argument("--model_max_length", default=1024, type=int)
parser.add_argument("--cot_backup", action='store_true', default=False)
args = parser.parse_args()
DTYPES = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
def run_question_answer(questions: list, groundtruths: list, collect_rerun: bool = False):
    used_examples = get_examples(args.dataset, args.shots, args.stem_flan_type)
    outputs = utils.get_answer(
        examples=used_examples,
        questions=questions,
        model=model,
        tokenizer=tokenizer,
        form=args.form,
        max_length=args.model_max_length)

    # We need to collect the values and possibly the rerun questions;
    returned_value = []
    rerun_questions = []
    rerun_groundtruths = []
    for output, question, groundtruth in zip(outputs, questions, groundtruths):
        if 'print(' in output:
            # PoT branch: the model produced a program; execute it to get the answer.
            output = output.split("### Instruction")[0]
            tmp_exec = utils.execute_with_timeout(output)
            tmp = 'The answer is' + ' ' + tmp_exec
            answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), tmp)
            # we rerun when exec with failure
            if not tmp_exec and collect_rerun:
                rerun_questions.append(utils.remove_flan_tag(question, args.stem_flan_type))
                # print('Adding back', rerun_questions[-1])
                rerun_groundtruths.append(groundtruth)
                continue
        else:
            # CoT branch: extract the answer directly from the generated text.
            answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), output)
        returned_value.append((question, output, answer, groundtruth))

    if collect_rerun:
        assert len(returned_value) + len(rerun_questions) == len(questions) == len(groundtruths)
        return returned_value, rerun_questions, rerun_groundtruths
    else:
        return returned_value
if __name__ == "__main__":
# Load model directly
tokenizer = AutoTokenizer.from_pretrained(
args.model,
padding_side="left",
trust_remote_code=True)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map="auto",
load_in_8bit=args.load_8bit,
torch_dtype=DTYPES[args.dtype],
trust_remote_code=True)
model.eval()
correct, wrong = 0, 0
if not args.output:
suffix = 'PoT' if 'pot' in args.stem_flan_type.lower() else 'CoT'
filename = args.model.split('/')[-1].replace('-', '_') + '_' + args.dataset
filename += '_' + f'{args.shots}shots' + '_' + args.form
filename += f'_length{args.model_max_length}'
filename += '_' + f'bs{args.batch_size}' + '_' + suffix
args.output = f'outputs/{filename}.jsonl'
print('Writing the output to', args.output)
file_handle = open(args.output, 'w')
match_answer_count, pot, cot = 0, 0, 0
for questions, groundtruths in tqdm(BatchDatasetLoader(args.dataset, args.batch_size)):
processed_questions = utils.process_question_with_flan_tag(questions, args.stem_flan_type)
if args.stem_flan_type == 'pot_prompt' and args.cot_backup:
returned_values, rerun_questions, rerun_groundtruths = run_question_answer(processed_questions, groundtruths, collect_rerun=True)
pot += len(returned_values)
cot += len(rerun_questions)
if rerun_questions:
processed_questions = utils.process_question_with_flan_tag(rerun_questions, "")
tmp = run_question_answer(processed_questions, rerun_groundtruths, collect_rerun=False)
returned_values += tmp
else:
returned_values = run_question_answer(processed_questions, groundtruths, collect_rerun=False)
for question, output, answer, groundtruth in returned_values:
# If the answer is not an option at all.
if answer not in ['A', 'B', 'C', 'D', 'E']:
options = utils.recover_options(question, combined=True)
prompt = f'Please find the closest option to {answer[:100]}. The options are {options}'
tmp = utils.get_answer(
examples=[],
questions=[prompt],
model=model,
tokenizer=tokenizer,
form=args.form)[0]
answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), tmp)
match_answer_count += 1
# Compare to get the accuracy
if answer == groundtruth:
correct += 1
else:
wrong += 1
if args.print:
print(answer, '#', groundtruth, '#', 'Answer Option Matches:', match_answer_count, 'CoT/PoT', f'{cot}/{pot}', '#', correct / (correct + wrong))
example = {
'question': question,
'correct': groundtruth,
'solution': output,
'pred': answer,
'task': args.dataset,
}
file_handle.write(json.dumps(example) + '\n')
print('final accuracy: ', correct / (correct + wrong), 'call answer matching: ', match_answer_count)
file_handle.close()
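# Example invocation (an illustrative sketch only: the script file name and the model
# path are assumptions, while the flags are the ones defined by the argparse block above):
#   python run_choice.py --model TIGER-Lab/MAmmoTH-7B --dataset aqua \
#       --stem_flan_type pot_prompt --shots 0 --batch_size 8 --print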

View File

@ -0,0 +1,158 @@
# Load model directly
import torch
from prompt_utils import get_prompt
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import argparse
import utils
from prompt_utils import *
from data_loader import BatchDatasetLoader
from tqdm import tqdm
from vllm import LLM, SamplingParams
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='', type=str)
parser.add_argument("--output", default='', type=str)
parser.add_argument("--stem_flan_type", default='', choices=['', 'pot_prompt'], type=str)
parser.add_argument("--dtype", default='bfloat16', type=str)
parser.add_argument("--dataset", required=True, choices=[
'gsm8k', 'svamp', 'math', 'numglue', 'deepmind', 'simuleq'], type=str)
parser.add_argument("--use_vllm", action='store_true', default=False)
parser.add_argument("--load_8bit", action='store_true', default=False)
parser.add_argument("--form", default='alpaca', type=str)
parser.add_argument("--shots", default=0, type=int)
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--print", action='store_true', default=False)
parser.add_argument("--model_max_length", default=1024, type=int)
parser.add_argument("--cot_backup", action='store_true', default=False)
args = parser.parse_args()
DTYPES = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
def run_question_answer(questions: list, groundtruths: list, collect_rerun: bool = False):
used_examples = get_examples(args.dataset, args.shots, args.stem_flan_type)
if args.use_vllm:
prompt_no_input, prefix = get_prompt(used_examples, args.form)
input_strs = [prompt_no_input + prefix.format(query=q) for q in questions]
outputs = llm.generate(input_strs, sampling_params)
outputs = [output.outputs[0].text for output in outputs]
else:
outputs = utils.get_answer(
examples=used_examples,
questions=questions,
model=model,
tokenizer=tokenizer,
form=args.form,
max_length=args.model_max_length)
# We need to collect the values and possibly the rerun questions;
returned_value = []
rerun_questions = []
rerun_groundtruths = []
for output, question, groundtruth in zip(outputs, questions, groundtruths):
if 'print(' in output:
output = output.split("### Instruction")[0]
tmp = utils.execute_with_timeout(output)
tmp = 'The answer is' + ' ' + tmp
answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), tmp)
else:
answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), output)
if answer == "" and collect_rerun:
rerun_questions.append(utils.remove_flan_tag(question, args.stem_flan_type))
# print('Adding back', rerun_questions[-1])
rerun_groundtruths.append(groundtruth)
continue
returned_value.append((question, output, answer, groundtruth))
if collect_rerun:
assert len(returned_value) + len(rerun_questions) == len(questions) == len(groundtruths)
return returned_value, rerun_questions, rerun_groundtruths
else:
return returned_value
if __name__ == "__main__":
if args.use_vllm:
stop_tokens = ["USER:", "USER", "ASSISTANT:", "ASSISTANT", "### Instruction:", "### Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=args.model_max_length, stop=stop_tokens)
llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count(), dtype=args.dtype, trust_remote_code=True)
args.batch_size = -1
print('Using VLLM, we do not need to set batch size!')
else:
tokenizer = AutoTokenizer.from_pretrained(
args.model,
padding_side="left",
trust_remote_code=True)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map="auto",
load_in_8bit=args.load_8bit,
torch_dtype=DTYPES[args.dtype],
trust_remote_code=True)
model.eval()
correct, wrong = 0, 0
if not args.output:
suffix = 'PoT' if 'pot' in args.stem_flan_type.lower() else 'CoT'
filename = args.model.strip('/').split('/')[-1].replace('-', '_') + '_' + args.dataset
filename += '_' + f'{args.shots}shots' + '_' + args.form
filename += f'_length{args.model_max_length}'
if args.cot_backup:
filename += '_CoTBackup'
filename += '_' + f'bs{args.batch_size}' + '_' + suffix
args.output = f'outputs/{filename}.jsonl'
print('Writing the output to', args.output)
file_handle = open(args.output, 'w')
for questions, groundtruths in tqdm(BatchDatasetLoader(args.dataset, args.batch_size)):
# First pass to use PoT
processed_questions = utils.process_question_with_flan_tag(questions, args.stem_flan_type)
if args.stem_flan_type == 'pot_prompt' and args.cot_backup:
            # if there is hybrid decoding, we try PoT first and then CoT
returned_values, rerun_questions, rerun_groundtruths = run_question_answer(processed_questions, groundtruths, collect_rerun=True)
if rerun_questions:
                # rerun the questions for which PoT produced no extractable answer, this time with CoT
processed_questions = utils.process_question_with_flan_tag(rerun_questions, "")
tmp = run_question_answer(processed_questions, rerun_groundtruths, collect_rerun=False)
returned_values += tmp
else:
            # single-mode decoding (CoT only or PoT only), so no rerun is needed
returned_values = run_question_answer(processed_questions, groundtruths, collect_rerun=False)
for question, output, answer, groundtruth in returned_values:
if args.dataset == 'math':
assert len(groundtruth) == 2, groundtruth
groundtruth_str, groundtruth_num = groundtruth
if utils.compare_both_string_and_number_format(answer, groundtruth_str, groundtruth_num):
correct += 1
else:
wrong += 1
else:
if answer == groundtruth:
correct += 1
else:
wrong += 1
if args.print:
print(answer, '#', groundtruth, '#', correct / (correct + wrong))
example = {
'question': question,
'correct': groundtruth,
'solution': output,
'pred': answer,
'task': args.dataset
}
file_handle.write(json.dumps(example) + '\n')
print('finished one epoch')
print('final accuracy: ', correct / (correct + wrong))
file_handle.close()
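# Example invocation (an illustrative sketch; the script file name and model path are
# assumptions, the flags come from the argparse block above). Note that --cot_backup
# only takes effect together with --stem_flan_type pot_prompt: PoT is tried first and
# questions without an extractable answer are rerun with CoT.
#   python run_open.py --model TIGER-Lab/MAmmoTH-7B --dataset gsm8k \
#       --stem_flan_type pot_prompt --cot_backup --use_vllm --print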

View File

@ -0,0 +1,120 @@
# Load model directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import argparse
import utils
from prompt_utils import *
from data_loader import BatchDatasetLoader
from tqdm import tqdm
from collections import Counter
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='', type=str)
parser.add_argument("--output", default='', type=str)
parser.add_argument("--stem_flan_type", default='', choices=['', 'pot_prompt'], type=str)
parser.add_argument("--dtype", default='bfloat16', type=str)
parser.add_argument("--dataset", required=True, choices=['gsm8k', 'svamp', 'math', 'numglue', 'deepmind', 'simuleq'], type=str)
parser.add_argument("--load_8bit", action='store_true', default=False)
parser.add_argument("--form", default='alpaca', type=str)
parser.add_argument("--shots", default=0, type=int)
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--num_samples", default=10, type=int)
parser.add_argument("--print", action='store_true', default=False)
parser.add_argument("--model_max_length", default=1024, type=int)
args = parser.parse_args()
DTYPES = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}
def run_question_answer_ensemble(questions: list, groundtruths: list, num_samples: int):
used_examples = get_examples(args.dataset, args.shots, args.stem_flan_type)
outputs = utils.get_ensemble_answer(
examples=used_examples,
questions=questions,
model=model,
tokenizer=tokenizer,
form=args.form,
num_samples=num_samples,
max_length=args.model_max_length)
# We need to collect the values and possibly the rerun questions;
returned_value = []
for i in range(len(questions)):
question, groundtruth = questions[i], groundtruths[i]
cur_answers = Counter()
for output in outputs[i * num_samples: (i + 1) * num_samples]:
if 'print(' in output:
output = output.split("### Instruction")[0]
tmp = utils.execute_with_timeout(output)
tmp = 'The answer is' + ' ' + tmp
answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), tmp)
else:
answer = utils.answer_clean(args.dataset, ('####', 'The answer is'), output)
cur_answers.update([answer])
answer = list(cur_answers.most_common())[0][0]
        returned_value.append((question, outputs[i * num_samples: (i + 1) * num_samples], answer, groundtruth))
return returned_value
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
args.model,
padding_side="left",
trust_remote_code=True)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map="auto",
load_in_8bit=args.load_8bit,
torch_dtype=DTYPES[args.dtype],
trust_remote_code=True)
model.eval()
correct, wrong = 0, 0
if not args.output:
suffix = 'PoT' if 'pot' in args.stem_flan_type.lower() else 'CoT'
filename = args.model.split('/')[-1].replace('-', '_') + '_' + args.dataset
filename += '_' + f'{args.shots}shots' + '_' + args.form
filename += f'_length{args.model_max_length}'
filename += '_' + f'bs{args.batch_size}' + '_' + f'sc{args.num_samples}' + '_' + suffix
args.output = f'outputs/{filename}.jsonl'
print('Writing the output to', args.output)
file_handle = open(args.output, 'w')
for questions, groundtruths in tqdm(BatchDatasetLoader(args.dataset, args.batch_size)):
# First pass to use PoT
processed_questions = utils.process_question_with_flan_tag(questions, args.stem_flan_type)
returned_values = run_question_answer_ensemble(processed_questions, groundtruths, args.num_samples)
for question, output, answer, groundtruth in returned_values:
if args.dataset == 'math':
assert len(groundtruth) == 2, groundtruth
groundtruth_str, groundtruth_num = groundtruth
if utils.compare_both_string_and_number_format(answer, groundtruth_str, groundtruth_num):
correct += 1
else:
wrong += 1
else:
if answer == groundtruth:
correct += 1
else:
wrong += 1
if args.print:
print(answer, '#', groundtruth, '#', correct / (correct + wrong))
example = {
'question': question,
'correct': groundtruth,
'solution': output,
'pred': answer,
'task': args.dataset
}
file_handle.write(json.dumps(example) + '\n')
print('final accuracy: ', correct / (correct + wrong))
file_handle.close()
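# Example invocation (an illustrative sketch; the script file name and model path are
# assumptions). The script draws --num_samples generations per question at temperature
# 0.7 and takes a majority vote over the cleaned answers:
#   python run_self_consistency.py --model TIGER-Lab/MAmmoTH-7B --dataset gsm8k \
#       --stem_flan_type pot_prompt --num_samples 10 --print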

View File

@ -0,0 +1,593 @@
import json
import re
from prompt_utils import get_prompt
from transformers import GenerationConfig
from io import StringIO
from contextlib import redirect_stdout
import math
import multiprocessing
import threading
from functools import lru_cache
import os
import torch
def format_code(code_str: str):
code = 'def run_it():\n'
for line in code_str.split('\n'):
code += ' ' + line + '\n'
code += 'run_it()'
return code
class CodeExecutor:
def __init__(self, code, timeout, use_process: bool):
self.code = format_code(code)
self.timeout = timeout
self.error = ''
self.use_process = use_process
def execute_code(self, return_val):
try:
f = StringIO()
with redirect_stdout(f):
exec(self.code, globals(), locals())
s = f.getvalue()
s = s.strip('\n')
return_val['result'] = s
except Exception:
pass
@staticmethod
def execute_code_with_string(code, index, return_val):
code = format_code(code)
try:
f = StringIO()
with redirect_stdout(f):
exec(code, globals(), locals())
s = f.getvalue()
s = s.strip('\n')
return_val[index] = s
except Exception as e:
# print(e)
pass
def run(self):
if self.use_process:
manager = multiprocessing.Manager()
return_dict = manager.dict()
process = multiprocessing.Process(
target=self.execute_code, args=(return_dict,))
process.start()
process.join(timeout=self.timeout)
process.terminate()
else:
return_dict = {}
thread = threading.Thread(
target=self.execute_code, args=(return_dict,))
thread.start()
thread.join(timeout=self.timeout)
if thread.is_alive():
                thread.join()  # Python threads cannot be forcibly killed; wait for the code to finish
print('time out!')
self.error = 'Execution timed out'
if 'result' in return_dict:
return return_dict['result']
else:
return ''
def read_jsonl(path: str):
with open(path, "r", encoding='utf-8') as fh:
return [json.loads(line) for line in fh.readlines() if line]
def extract_nums(s):
s = s.replace(",", "")
nums = re.findall(r"[+-]? *(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?", s)
return_list = []
for i in range(len(nums)):
try:
return_list.append(eval(nums[i].strip().lstrip(" 0")))
except:
pass
return return_list
def find_formula(step):
assert step.count("<<") == step.count(">>") == 1
left, right = step.find("<<")+2, step.find(">>")
return step[left: right]
def extract_answer(completion):
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
match = ANS_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
else:
assert False
def delete_extra_zero(n):
    '''Remove redundant trailing zeros after the decimal point.'''
try:
n=float(n)
except:
print("None {}".format(n))
return n
if isinstance(n, int):
return str(n)
if isinstance(n, float):
        n = str(n).rstrip('0')  # strip redundant trailing zeros after the decimal point
        n = int(n.rstrip('.')) if n.endswith('.') else float(n)  # if only the decimal point remains, convert to int; otherwise convert back to float
n=str(n)
return n
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr:
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
# assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split and split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
# if string == "0.5":
# string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def extract_math_answer(pred_str):
if('The answer is ' in pred_str):
pred = pred_str.split('The answer is ')[-1].strip()
elif('the answer is ' in pred_str):
pred = pred_str.split('the answer is ')[-1].strip()
elif 'boxed' in pred_str:
ans = pred_str.split('boxed')[-1]
if not ans:
return ""
if (ans[0] == '{'):
stack = 1
a = ''
for c in ans[1:]:
if (c == '{'):
stack += 1
a += c
elif (c == '}'):
stack -= 1
if (stack == 0): break
a += c
else:
a += c
else:
a = ans.split('$')[0].strip()
a = _strip_string(a)
pred=a
else:
pattern = '-?\d*\.?\d+'
pred = re.findall(pattern, pred_str)
if(len(pred) >= 1):
pred = pred[-1]
else: pred = ''
if pred != "":
if pred[-1] == ".":
pred = pred[:-1]
if pred != "" and pred[-1] == "/":
pred = pred[:-1]
pred=_strip_string(pred)
if 'boxed' in pred:
ans = pred.split('boxed')[-1]
if not ans:
return ""
if (ans[0] == '{'):
stack = 1
a = ''
for c in ans[1:]:
if (c == '{'):
stack += 1
a += c
elif (c == '}'):
stack -= 1
if (stack == 0): break
a += c
else:
a += c
else:
a = ans.split('$')[0].strip()
a = _strip_string(a)
pred=a
return pred
def answer_clean(dataset: str, direct_answer_trigger_for_fewshot: tuple, pred: str):
if dataset == "math":
if len(pred) > 0:
pred_final=extract_math_answer(pred)
return pred_final
else:
return pred
# Determine if this is ICL, if so, use \n\n to split the first chunk.
ICL = False
for trigger in direct_answer_trigger_for_fewshot:
if pred.count(trigger) > 1:
ICL = True
if ICL:
pred = pred.split('\n\n')[0]
# Split the trigger to find the answer.
preds = re.split('|'.join(direct_answer_trigger_for_fewshot), pred)
answer_flag = True if len(preds) > 1 else False
pred = preds[-1]
if '=' in pred:
pred = pred.split('=')[-1].strip()
if dataset in ("aqua", "sat", "mmlu_mathematics", "mmlu_physics", "mmlu_chemistry", "mmlu_biology"):
tmp = re.findall(r'\b(A|B|C|D|E)\b', pred.upper())
if tmp:
pred = tmp
else:
pred = [pred.strip().strip('.')]
elif dataset in ("gsm8k", "svamp", "deepmind", "simuleq"):
pred = pred.replace(",", "")
pred = [delete_extra_zero(s.replace(",", "")) for s in re.findall(r'-?\d+/?\.?\d*', pred)]
elif dataset in ("numglue",):
tmp = re.findall(r'\b(A|B|C|D|E)\b', pred.upper())
if tmp:
pred = tmp
else:
pred = pred.replace(",", "")
pred = [delete_extra_zero(s.replace(",", "")) for s in re.findall(r'-?\d+/?\.?\d*', pred)]
else:
raise ValueError("dataset is not properly defined ...")
    # If there is no candidate in the list, set the prediction to an empty string.
if len(pred) == 0:
pred = ""
else:
if answer_flag:
# choose the first element in list ...
pred = pred[0]
else:
            # choose the last element in list ...
pred = pred[-1]
# (For arithmetic tasks) if a word ends with period, it will be omitted ...
if pred != "":
if pred[-1] == ".":
pred = pred[:-1]
if pred[-1] == "/":
pred = pred[:-1]
return pred
def get_answer(examples, questions, model, tokenizer, form,
max_length: int = 300, do_sample: bool = False):
prompt_no_input, prefix = get_prompt(examples, form=form)
# Formulate the real prompt
input_strs = [prompt_no_input + prefix.format(query=q) for q in questions]
batch = tokenizer(
input_strs,
padding=True,
return_tensors="pt",
)
with torch.no_grad():
output_ids = model.generate(
batch.input_ids.to(model.device),
attention_mask=batch.attention_mask.to(model.device),
pad_token_id=tokenizer.pad_token_id,
generation_config=GenerationConfig(
do_sample=do_sample,
max_new_tokens=max_length,
trust_remote_code=True)
)
output_strs = []
for output_id in output_ids.tolist():
tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
output_strs.append(tmp)
return output_strs
def get_ensemble_answer(examples, questions, model, tokenizer, form, num_samples: int, max_length: int = 300):
prompt_no_input, prefix = get_prompt(examples, form=form)
# Formulate the real prompt
input_strs = [prompt_no_input + prefix.format(query=q) for q in questions]
batch = tokenizer(
input_strs,
padding=True,
return_tensors="pt",
)
with torch.no_grad():
output_ids = model.generate(
batch.input_ids.to(model.device),
attention_mask=batch.attention_mask.to(model.device),
pad_token_id=tokenizer.pad_token_id,
generation_config=GenerationConfig(
do_sample=True,
max_new_tokens=max_length,
trust_remote_code=True,
num_return_sequences=num_samples,
temperature=0.7)
)
output_strs = []
for output_id in output_ids.tolist():
tmp = tokenizer.decode(output_id[batch.input_ids.shape[-1]:], skip_special_tokens=True)
output_strs.append(tmp)
return output_strs
def execute_with_timeout(code: str, timeout: int=5, use_process: bool = True):
executor = CodeExecutor(code, timeout, use_process)
s = executor.run()
return s
def within_eps(pred: float, gt: float):
eps = abs(gt) * 0.04
if pred >= gt - eps and pred <= gt + eps:
return True
else:
return False
def floatify(num: str):
try:
num = float(num)
if num.is_integer():
return round(num)
else:
return num
except Exception:
return None
def number_it(num: str):
if 'frac' in num:
pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}"
num = re.sub(pattern, r"\1/\2", num)
try:
num = str(eval(num))
except Exception:
pass
elif ',' in num:
num = num.replace(',', '')
if floatify(num) is not None:
return floatify(num)
else:
try:
num = eval(num)
if isinstance(num, list) or isinstance(num, tuple):
num = num[0]
if floatify(num) is not None:
return floatify(num)
else:
return None
except Exception:
return None
def compare_two_numbers(p, gt):
try:
if math.isnan(p):
return False
if isinstance(gt, int):
return round(p) == gt
else:
return within_eps(pred=p, gt=gt)
except Exception:
return False
def get_decimal_with_wolfram(string: str) -> float:
import wolframalpha
wolfram_client = wolframalpha.Client('AU7JWQ-QQUV8K8QLQ')
for ex in wolfram_client.query(f'compute {string}').pods:
if ex['@title'] in ['Decimal approximation', 'Decimal form']:
for sub in ex.subpods:
try:
return float(sub['plaintext'][:20])
except Exception:
pass
for ex in wolfram_client.query(f'compute {string}').pods:
if ex['@title'] in ['Result']:
for sub in ex.subpods:
try:
return float(sub['plaintext'][:8])
except Exception:
pass
return None
@lru_cache(maxsize=None)
def compare_both_string_and_number_format(answer, groundtruth_str, groundtruth_num):
if answer == groundtruth_str:
return True
else:
if groundtruth_num is not None and number_it(answer) is not None:
if compare_two_numbers(number_it(answer), groundtruth_num):
return True
else:
return False
else:
return False
def process_question_with_flan_tag(questions: list, stem_flan_type: str):
if stem_flan_type == "pot_prompt":
prefix = " Let's write a program."
elif stem_flan_type == "":
prefix = ""
else:
prefix = " " + stem_flan_type
questions = [q + prefix for q in questions]
return questions
def remove_flan_tag(question: str, stem_flan_type: str):
if stem_flan_type == "pot_prompt":
question = question.replace(" Let's write a program.", "")
else:
question = question.replace(" " + stem_flan_type, "")
return question
def recover_options(input_str: str, combined: bool = False):
options = input_str.split('Answer Choices:')[-1].strip()
if 'Let\'s' in options:
options = options[:options.index('Let\'s')]
if combined:
return options
else:
index_1, index_2, index_3, index_4 = options.find('(A)'), options.find('(B)'), options.find('(C)'), options.find('(D)')
if '(E)' in options:
            index_5 = options.find('(E)')
        option_a = options[index_1+3:index_2].strip()
        option_b = options[index_2+3:index_3].strip()
        option_c = options[index_3+3:index_4].strip()
        if '(E)' in options:
            option_d = options[index_4+3:index_5].strip()
            option_e = [options[index_5+3:].strip()]
        else:
            option_d = options[index_4+3:].strip()
            option_e = []
        return [option_a, option_b, option_c, option_d] + option_e
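# A minimal, illustrative sanity check of the grading helpers above (not part of the
# original file); run this module directly to execute it. The concrete values below are
# examples chosen for illustration, not fixtures from the repository.
if __name__ == "__main__":
    # number_it rewrites a simple \frac{a}{b} into a/b and evaluates it
    assert number_it("\\frac{3}{4}") == 0.75
    # within_eps allows a 4% relative tolerance, so 0.74 matches a ground truth of 0.75
    assert compare_two_numbers(0.74, 0.75)
    # answer_clean extracts and normalizes the final number after the trigger phrase
    assert answer_clean("gsm8k", ("####", "The answer is"), "The answer is 72.") == "72"
    # execute_with_timeout captures the stdout of a generated program (PoT output)
    assert execute_with_timeout("print(1 + 1)") == "2"
    print("sanity checks passed")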

View File

@ -0,0 +1 @@
This folder contains all the outputs

View File

@ -0,0 +1,10 @@
numpy
fire
transformers>=4.34.0
torch==2.0.1
sentencepiece
tokenizers
wandb
accelerate
protobuf
vllm

View File

@ -0,0 +1,29 @@
export MODEL_PATH=[MODEL_PATH]
export OUTPUT_PATH="checkpoints/[OUTPUT_DIR]"
export MASTER_ADDR="localhost"
export WANDB_DISABLED=True
export NUM_GPUS=8
torchrun --master_addr ${MASTER_ADDR} \
--nproc_per_node=${NUM_GPUS} \
--master_port=6008 \
train.py \
--model_name_or_path $MODEL_PATH \
--data_path "TIGER-Lab/MathInstruct" \
--bf16 True \
--output_dir ${OUTPUT_PATH}\
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000\
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--tf32 True

View File

@ -0,0 +1,36 @@
export MODEL_PATH='path_to_Llama-2-7b-hf'
export OUTPUT_PATH='llama2_mft_m0.2'
export MASTER_ADDR="localhost"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export NUM_GPUS=8
export mask_p=0.2
export warmup_steps=48000
export lr_decay_ratio=0.6
export min_lr_ratio=0.05
export decay=True
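# These variables are read by train.py via os.environ: mask_p sets the target probability
# of replacing response tokens with the added <mask> token, warmup_steps ramps that
# probability up linearly over the first collator steps, and lr_decay_ratio / min_lr_ratio /
# decay shape the modified cosine learning-rate schedule patched into transformers.optimization.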
torchrun --master_addr ${MASTER_ADDR} \
--nproc_per_node=${NUM_GPUS} \
--master_port=6008 \
train.py \
--model_name_or_path $MODEL_PATH \
--data_path "../data/MathInstruct/MathInstruct.json" \
--bf16 True \
--output_dir ${OUTPUT_PATH}\
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 4 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True \
--model_max_length 1024

View File

@ -0,0 +1,32 @@
export MODEL_PATH='path_to_Llama-2-7b-hf'
export OUTPUT_PATH='llama2_sft'
export MASTER_ADDR="localhost"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export NUM_GPUS=8
torchrun --master_addr ${MASTER_ADDR} \
--nproc_per_node=${NUM_GPUS} \
--master_port=6008 \
train.py \
--model_name_or_path $MODEL_PATH \
--data_path "../data/MathInstruct/MathInstruct.json" \
--bf16 True \
--output_dir ${OUTPUT_PATH}\
--num_train_epochs 3 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 4 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True \
--model_max_length 1024

View File

@ -0,0 +1,33 @@
export MODEL_PATH="mistralai/Mistral-7B-v0.1"
export OUTPUT_PATH="checkpoints/[OUTPUT_DIR]"
export MASTER_ADDR="localhost"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export NUM_GPUS=8
export mask_p=0.2
export warmup_steps=48000
export lr_decay_ratio=0.6
export min_lr_ratio=0.05
export decay=True
torchrun --master_addr ${MASTER_ADDR} --nproc_per_node=${NUM_GPUS} --master_port=6008 train.py \
--model_name_or_path $MODEL_PATH \
--data_path "MathInstruct/MathInstruct.json" \
--bf16 True \
--output_dir ${OUTPUT_PATH}\
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000\
--save_total_limit 20 \
--learning_rate 5e-6 \
--weight_decay 0. \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
--tf32 True \
--model_max_length 1024

View File

@ -0,0 +1,27 @@
export MODEL_PATH="mistralai/Mistral-7B-v0.1"
export OUTPUT_PATH="checkpoints/[OUTPUT_DIR]"
export MASTER_ADDR="localhost"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export NUM_GPUS=8
torchrun --master_addr ${MASTER_ADDR} --nproc_per_node=${NUM_GPUS} --master_port=6008 train.py \
--model_name_or_path $MODEL_PATH \
--data_path "MathInstruct/MathInstruct.json" \
--bf16 True \
--output_dir ${OUTPUT_PATH}\
--num_train_epochs 3 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000\
--save_total_limit 4 \
--learning_rate 5e-6 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
--tf32 True

View File

@ -0,0 +1,445 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import os
import json
import datasets
import torch
import transformers
import math
from torch.utils.data import Dataset
from transformers import Trainer
import pathlib
import utils
import random
lr_decay_ratio = os.environ.get('lr_decay_ratio')
min_lr_ratio = os.environ.get('min_lr_ratio')
drop_mask = os.environ.get('drop_mask')
mask_probability = os.environ.get('mask_p')
dec = os.environ.get('dec')
decay = os.environ.get('decay')
random_mask_p = os.environ.get('random_mask_p')
warmup_steps = os.environ.get('warmup_steps')
use_random = os.environ.get('use_random')
late_mask = os.environ.get('late_mask')
late_mask_steps = os.environ.get('late_mask_steps')
if lr_decay_ratio is None:
lr_decay_ratio = 1
else:
lr_decay_ratio = float(lr_decay_ratio)
if min_lr_ratio is None:
min_lr_ratio = 0
else:
min_lr_ratio = float(min_lr_ratio)
if drop_mask is None:
drop_mask = False
else:
drop_mask = True
if late_mask is None:
late_mask = False
else:
late_mask = True
if use_random is not None:
use_random = True
if random_mask_p is not None:
random_mask_p = True
else:
random_mask_p = False
if warmup_steps is None:
warmup_steps = 0
else:
warmup_steps = int(warmup_steps)
if late_mask_steps is None:
late_mask_steps = 0
else:
late_mask_steps = int(late_mask_steps)
if mask_probability is None:
mask_probability = 0
if dec is not None:
dec = True
else:
dec = False
if decay is not None:
decay = True
else:
decay = False
mask_probability = float(mask_probability)
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
padding_side: Optional[str] = field(default="right")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
template_variation: bool = field(
default=True, metadata={"help": "whether to use template variation"})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
flash_attn: bool = field(default=False)
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def mask_target_tokens(input_ids, sources_tokenized, mask_probability, MASK_INDEX, tokenizer):
    masked_input_ids = copy.deepcopy(input_ids)
    vocab_size = int(tokenizer.vocab_size)
    for input_id, source_len in zip(masked_input_ids, sources_tokenized["input_ids_lens"]):
        # only the target (response) tokens are candidates for masking
        for i in range(source_len, len(input_id)):
            if random.random() < mask_probability:
                if use_random:
                    # replace with a random vocabulary token instead of <mask>
                    input_id[i] = random.randint(0, vocab_size - 1)
                else:
                    input_id[i] = MASK_INDEX
    return masked_input_ids
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
steps
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
if mask_probability != 0:
if dec:
if warmup_steps > 0 and steps < warmup_steps:
_mask_probability = (warmup_steps - steps) / warmup_steps * mask_probability * 2
else:
_mask_probability = mask_probability
else:
if warmup_steps > 0 and steps < warmup_steps:
_mask_probability = steps / warmup_steps * mask_probability
else:
_mask_probability = mask_probability
if late_mask:
if steps < late_mask_steps:
_mask_probability = 0
elif warmup_steps > 0 and (steps - late_mask_steps) < warmup_steps:
_mask_probability = (steps-late_mask_steps) / warmup_steps * mask_probability
else:
_mask_probability = mask_probability
# if random_mask_p:
# _mask_probability = random.random() * mask_probability
# assert 1==0
MASK_INDEX = tokenizer.convert_tokens_to_ids(['<mask>'])[0]
# print(MASK_INDEX, tokenizer.pad_token_id)
# assert 1==0
masked_input_ids = mask_target_tokens(input_ids, sources_tokenized, _mask_probability, MASK_INDEX, tokenizer)
input_ids = masked_input_ids
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, template_variation: bool):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
if os.path.exists(data_path):
with open(data_path) as f:
list_data_dict = json.load(f)
else:
list_data_dict = datasets.load_dataset(data_path)["train"]
logging.warning("Formatting inputs...")
if template_variation:
PROMPT_DICT = random.choice(utils.PROMPT_TEMPLATE)
else:
PROMPT_DICT = utils.PROMPT_TEMPLATE_SINGLE
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = []
for example in list_data_dict:
if example.get("input", "") != "":
sources.append(prompt_input.format_map(example))
else:
sources.append(prompt_no_input.format_map(example))
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
self.sources = sources
self.targets = targets
def __len__(self):
return len(self.sources)
def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __getitem__(self, i):
return dict(input_ids=self.sources[i], labels=self.targets[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
steps: int = 0
def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
if drop_mask:
MASK_INDEX = self.tokenizer.convert_tokens_to_ids(['<mask>'])[0]
attention_mask = attention_mask & input_ids.ne(MASK_INDEX)
else:
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
sources = []
targets = []
for instance in instances:
source = instance['input_ids']
target = instance['labels']
sources.append(source)
targets.append(target)
self.steps += 1
data_dict = preprocess(sources, targets, self.tokenizer, self.steps)
input_ids, labels = data_dict['input_ids'], data_dict['labels']
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path,
template_variation=data_args.template_variation)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def train():
transformers.logging.set_verbosity_info()
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
print(training_args)
print('Start Loading Model')
if training_args.flash_attn:
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
use_cache=False,
).to('cuda')
else:
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
print('Start building tokenizer')
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
# padding_side="right",
padding_side=model_args.padding_side,
use_fast=False,
)
print("*"*50)
print("Before adding, tokenizer length: ",len(tokenizer))
special_tokens_dict = dict()
if mask_probability != 0:
tokenizer.add_tokens(["<mask>"])
def resize(model):
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
old_num_tokens = input_embeddings.shape[0]
num_new_tokens = len(tokenizer) - old_num_tokens
print('num_new_tokens', num_new_tokens)
if num_new_tokens != 0:
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
# model.resize_token_embeddings(len(auto_tokenizer), pad_to_multiple_of=8)
model.resize_token_embeddings(len(tokenizer))
model.get_input_embeddings().weight.data[-num_new_tokens:] = input_embeddings_avg
model.get_output_embeddings().weight.data[-num_new_tokens:] = output_embeddings_avg
resize(model)
if tokenizer.pad_token is None:
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
smart_tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
)
print("*"*50)
print("After adding, tokenizer length: ",len(tokenizer))
print('Start building data module')
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
print('Start building the trainer module')
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
# lr_decay_ratio = train_args.lr_decay_ratio
# min_lr_ratio = train_args.min_lr_ratio
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
lr_decay_steps = num_training_steps * lr_decay_ratio
if current_step > lr_decay_steps:
if decay:
remaining_steps = num_training_steps - current_step
# return min_lr_ratio * (remaining_steps / max(1, remaining_steps))
return max(0, min_lr_ratio * (remaining_steps / max(1, num_training_steps * (1 - lr_decay_ratio))))
return min_lr_ratio
progress = float(current_step - num_warmup_steps) / float(max(1, lr_decay_steps - num_warmup_steps))
coefficient = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return min_lr_ratio + coefficient * (1.0 - min_lr_ratio)
def add_lr_decay_limit_for_cosine_schedule():
transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda = _get_cosine_schedule_with_warmup_lr_lambda
add_lr_decay_limit_for_cosine_schedule()
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
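# Worked example of the schedules above (illustrative only, plugging in the values exported
# by the accompanying run_mft scripts: mask_p=0.2, warmup_steps=48000, dec unset,
# lr_decay_ratio=0.6, min_lr_ratio=0.05, decay=True):
# - masking: the effective mask probability ramps linearly from 0 to 0.2 over the first
#   48000 collator calls (e.g. 12000/48000 * 0.2 = 0.05 at step 12000) and stays at 0.2
#   afterwards; with dec set it would instead start at 2 * mask_p and shrink linearly
#   during the warmup window before settling at mask_p.
# - learning rate: the multiplier warms up linearly over the trainer's warmup steps,
#   follows a cosine from 1.0 down to min_lr_ratio=0.05 until 60% of training
#   (lr_decay_ratio), and then, because decay=True, falls linearly from 0.05 towards 0
#   over the remaining 40% of the steps.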

View File

@ -0,0 +1,194 @@
import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union, Dict
import tqdm
import copy
def _make_w_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f_dirname = os.path.dirname(f)
if f_dirname != "":
os.makedirs(f_dirname, exist_ok=True)
f = open(f, mode=mode)
return f
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jdump(obj, f, mode="w", indent=4, default=str):
"""Dump a str or dictionary to a file in json format.
Args:
obj: An object to be written.
f: A string path to the location on disk.
mode: Mode for opening the file.
indent: Indent for storing json dictionaries.
default: A function to handle non-serializable entries; defaults to `str`.
"""
f = _make_w_io_base(f, mode)
if isinstance(obj, (dict, list)):
json.dump(obj, f, indent=indent, default=default)
elif isinstance(obj, str):
f.write(obj)
else:
raise ValueError(f"Unexpected type: {type(obj)}")
f.close()
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict
PROMPT_TEMPLATE = [
{
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"You are supposed to follow an instruction, and then the input to generate proper response.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"You are supposed to follow an instruction to generate proper response."
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"Please follow the instruction and input to give a response.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Please follow the instruction to give a response."
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"You are an expert, please listen to human instruction and input to generate the response.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"You are an expert, please listen to human instruction to generate the response.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"Let's follow the instruction to respond to an input.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Let's follow the instruction to generate a response.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"The instruction is a description of the task. You need to follow that and respond to the paired input.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"The instruction is a description of the task. You need to follow that and respond.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
},
{
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"Instruction:\n{instruction}\n\nResponse:"
),
},
{
"prompt_input": (
"You are supposed to follow an instruction, and then the input to generate proper response.\n\n"
"#Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"You are supposed to follow an instruction to generate proper response."
"Instruction:\n{instruction}\n\nResponse:"
),
},
{
"prompt_input": (
"Please follow the instruction and input to give a response.\n\n"
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"Please follow the instruction to give a response."
"Instruction:\n{instruction}\n\nResponse:"
),
},
{
"prompt_input": (
"You are an expert, please listen to human instruction and input to generate the response.\n\n"
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"You are an expert, please listen to human instruction to generate the response.\n\n"
"Instruction:\n{instruction}\n\nResponse:"
),
},
{
"prompt_input": (
"Let's follow the instruction to respond to an input.\n\n"
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"Let's follow the instruction to generate a response.\n\n"
"Instruction:\n{instruction}\n\nResponse:"
),
},
{
"prompt_input": (
"The instruction is a description of the task. You need to follow that and respond to the paired input.\n\n"
"Instruction:\n{instruction}\n\nInput:\n{input}\n\nResponse:"
),
"prompt_no_input": (
"The instruction is a description of the task. You need to follow that and respond.\n\n"
"Instruction:\n{instruction}\n\nResponse:"
),
},
]
PROMPT_TEMPLATE_SINGLE = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
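# Illustrative usage (not part of the original file): formatting one training prompt with
# the fixed template; the instruction text here is a made-up placeholder.
#   PROMPT_TEMPLATE_SINGLE["prompt_no_input"].format(instruction="What is 12 * 7?")
# produces the Alpaca-style "### Instruction: ... ### Response:" string that precedes the
# target solution during supervised fine-tuning.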

View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,171 @@
# MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models
[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](CODE_LICENSE)
[![Model Weight License](https://img.shields.io/badge/Model%20Weights%20License-LLaMA2-yellow)](MetaMath/LICENSE)
[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/)
<p align="center">
🤗 <a href="https://huggingface.co/meta-math" target="_blank">HF Repo</a> • 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a><br>
</p>
<p align="center" width="100%">
<a ><img src="./imgs/metamath.svg" alt="MetaMath" style="width: 80%; min-width: 300px; display: block; margin: auto;"></a>
</p>
## News
- 🔥 Our **MetaMath-Llemma-7B** model achieves **30.0 pass@1** on the MATH Benchmarks, surpassing all SOTA open-source LLMs at the 7B-13B scale! All training scripts and the model are released.
- 🔥 Our **MetaMath-Mistral-7B** model achieves **77.7 pass@1** on the [GSM8k Benchmarks](https://github.com/openai/grade-school-math), surpassing all SOTA open-source LLMs! All training scripts and the model are released.
- 🔥 The full **MetaMathQA** dataset is now released on Hugging Face: [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA/tree/main)!
- 🔥 The **GSM8K_Backward** dataset is also released on Hugging Face ([GSM8K_Backward](https://huggingface.co/datasets/meta-math/GSM8K_Backward)) to evaluate reversal mathematical reasoning ability!
- 🔥 Although the data augmentation for **MetaMathQA** is sourced from **ChatGPT 3.5**, our **MetaMath-70B** model outperforms the closed-source **ChatGPT 3.5** on GSM8K!
- 🔥 Our **MetaMath-7B** model achieves **66.5 pass@1** on the [GSM8k Benchmarks](https://github.com/openai/grade-school-math), **11.6** points higher than the SOTA open-source LLM!
- 🔥 Our **MetaMath-7B** model achieves **19.8 pass@1** on the [MATH Benchmarks](https://github.com/hendrycks/math), **9.1** points higher than the SOTA open-source LLM!
| Model | Checkpoint | Paper | GSM8k | MATH | License|
| ----- |------| ---- |------|-------| ----- |
| MetaMath-70B-V1.0 | 🤗 <a href="https://huggingface.co/meta-math/MetaMath-70B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a>| **82.3** | **26.6** | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a> |
| MetaMath-13B-V1.0 | 🤗 <a href="https://huggingface.co/meta-math/MetaMath-13B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a>| **72.3** | **22.4** | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a> |
| MetaMath-7B-V1.0 | 🤗 <a href="https://huggingface.co/meta-math/MetaMath-7B-V1.0" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a>| **66.5** | **19.8** | <a href="https://ai.meta.com/resources/models-and-libraries/llama-downloads/" target="_blank">Llama 2 </a>|
| MetaMath-Mistral-7B | 🤗 <a href="https://huggingface.co/meta-math/MetaMath-Mistral-7B" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a>| **77.7** | **28.2** | <a href="http://www.apache.org/licenses/" target="_blank">Apache License 2.0 </a>|
| MetaMath-Llemma-7B | 🤗 <a href="https://huggingface.co/meta-math/MetaMath-Llemma-7B" target="_blank">HF Link</a> | 📃 <a href="https://arxiv.org/abs/2309.12284" target="_blank">[MetaMath]</a>| **69.2** | **30.0** | <a href="http://www.apache.org/licenses/" target="_blank">Apache License 2.0 </a>|
## Comparing MetaMath with other LLMs
🔥 Comprehensive Results
| Model | GSM8k Pass@1 | MATH Pass@1 |
|---------------------|--------------|-------------|
| MPT-7B | 6.8 | 3.0 |
| Falcon-7B | 6.8 | 2.3 |
| LLaMA-1-7B | 11.0 | 2.9 |
| LLaMA-2-7B | 14.6 | 2.5 |
| MPT-30B | 15.2 | 3.1 |
| LLaMA-1-13B | 17.8 | 3.9 |
| GPT-Neo-2.7B | 19.5 | -- |
| Falcon-40B | 19.6 | 2.5 |
| Baichuan-chat-13B | 23.9 | -- |
| Vicuna-v1.3-13B | 27.6 | -- |
| LLaMA-2-13B | 28.7 | 3.9 |
| InternLM-7B | 31.2 | -- |
| ChatGLM-2-6B | 32.4 | -- |
| GPT-J-6B | 34.9 | -- |
| LLaMA-1-33B | 35.6 | 3.9 |
| LLaMA-2-34B | 42.2 | 6.24 |
| RFT-7B | 50.3 | -- |
| LLaMA-1-65B | 50.9 | 10.6 |
| Qwen-7B | 51.6 | -- |
| WizardMath-7B | 54.9 | 10.7 |
| LLaMA-2-70B | 56.8 | 13.5 |
| WizardMath-13B | 63.9 | 14.0 |
| 🔥 MetaMath-7B | **66.5** | **19.8** |
| 🔥 MetaMath-13B | **72.3** | **22.4** |
| 🔥 MetaMath-Mistral-7B | **77.7** | **28.2** |
| 🔥 MetaMath-Llemma-7B | **69.2** | **30.0** |
| WizardMath-70B | 81.6 | 22.7 |
| 🔥 MetaMath-70B | **82.3** | **26.6** |
<h2 id="env">Quick Start</h2>
Clone MetaMath and install the required packages:
```bash
git clone https://github.com/meta-math/MetaMath.git
cd MetaMath
pip install -r requirements.txt
```
If you encounter a Ray installation problem, please run:
```bash
pip install --upgrade ray
pip install --upgrade pyarrow
```
<h2 id="Inference">Dataset Usage</h2>
Run the following command to load the data:
```python
from datasets import load_dataset
dataset = load_dataset("meta-math/MetaMathQA")
```
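As a quick sanity check, you can inspect a single record. The sketch below assumes the default `train` split and the `query`/`response` fields that the training and evaluation scripts in this repo expect:
```python
from datasets import load_dataset

dataset = load_dataset("meta-math/MetaMathQA")
sample = dataset["train"][0]
# Each record pairs an augmented question ('query') with a step-by-step answer ('response').
print(sample["query"])
print(sample["response"])
```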
<h2 id="train">Training</h2>
You need to prepare the LLaMA-2 base model and our **MetaMathQA** dataset from Hugging Face ([MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA/tree/main)), then run:
```
bash run.sh
```
or
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env train_math.py \
--model_name_or_path "path/to/llama-2" \
--data_path "path/to/metamathqa" \
--data_length 10000000 \
--bf16 True \
--output_dir "path/to/save" \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 2 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True
```
### Supervised fine-tuning
We supervised fine-tune MetaMath-7B with the following hyperparameters:
| Hyperparameter | LLaMA 2 7B |
|----------------|-------------|
| Batch size | 128 |
| Learning rate | 2e-5 |
| Epochs | 3 |
| Max length | 512 |
| LR scheduler | cosine |
<h2 id="evaluation">Evaluation</h2>
We use vLLM for fast generation:
```
python eval_gsm8k.py --model "path/to/save" --data_file ./data/test/GSM8K_test.jsonl
python eval_math.py --model "path/to/save" --data_file ./data/test/MATH_test.jsonl
```
where the "path/to/save" should be replaced by the finetuned model, you can also download our series of MetaMath models in huggingface:
🤗 <a href="https://huggingface.co/meta-math/MetaMath-7B-V1.0" target="_blank">MetaMath 7B</a> 🤗 <a href="https://huggingface.co/meta-math/MetaMath-13B-V1.0" target="_blank">MetaMath 13B</a> 🤗 <a href="https://huggingface.co/meta-math/MetaMath-70B-V1.0" target="_blank">MetaMath 70B</a>
The inference prompt for our MetaMath is:
```
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
```
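For reference, here is a minimal sketch of running this prompt through a finetuned checkpoint with vLLM, mirroring the greedy decoding setup of the evaluation scripts; the model path and the question are placeholders:
```python
from vllm import LLM, SamplingParams

PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
)

llm = LLM(model="path/to/save")  # placeholder: your finetuned MetaMath checkpoint
params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512)

question = "What is 15% of 240?"  # placeholder question
outputs = llm.generate([PROMPT.format(instruction=question)], params)
print(outputs[0].outputs[0].text)
```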
Thanks for the open source code of [WizardMath](https://github.com/nlpxucan/WizardLM/tree/main/WizardMath) and [RFT](https://github.com/OFA-Sys/gsm8k-ScRel/tree/main). Some of our codes are based on them.
<h2 id="citation">Citation</h2>
Please cite the paper if you use the model, code, or data from MetaMath.
```
@article{yu2023metamath,
title={MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models},
author={Yu, Longhui and Jiang, Weisen and Shi, Han and Yu, Jincheng and Liu, Zhengying and Zhang, Yu and Kwok, James T and Li, Zhenguo and Weller, Adrian and Liu, Weiyang},
journal={arXiv preprint arXiv:2309.12284},
year={2023}
}
```

View File

@ -0,0 +1,7 @@
# MetaMathQA Data
## Train Data
The full **MetaMathQA** dataset is now released on Hugging Face: [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA/tree/main)!
## Test Data
The **GSM8K_Backward** dataset is also released on Hugging Face ([GSM8K_Backward](https://huggingface.co/datasets/meta-math/GSM8K_Backward)) to evaluate reversal mathematical reasoning ability!
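It can be loaded directly with the `datasets` library; a minimal sketch:
```python
from datasets import load_dataset

# Load the reversal-reasoning evaluation set from the Hugging Face hub.
backward = load_dataset("meta-math/GSM8K_Backward")
print(backward)
```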

View File

@ -0,0 +1,3 @@
# MetaMathQA
The full **MetaMathQA** dataset is now released on Hugging Face: [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA/tree/main)

View File

@ -0,0 +1,134 @@
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
MAX_INT = sys.maxsize
def is_number(s):
try:
float(s)
return True
except ValueError:
pass
try:
import unicodedata
unicodedata.numeric(s)
return True
except (TypeError, ValueError):
pass
return False
def extract_answer_number(completion):
text = completion.split('The answer is: ')
if len(text) > 1:
extract_ans = text[-1].strip()
match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans)
if match:
if '/' in match.group():
denominator = match.group().split('/')[1]
numerator = match.group().split('/')[0]
if is_number(denominator) and is_number(numerator):
if denominator == '0':
return round(float(numerator.replace(',', '')))
else:
frac = Fraction(match.group().replace(',', ''))
num_numerator = frac.numerator
num_denominator = frac.denominator
return round(float(num_numerator / num_denominator))
else:
return None
else:
if float(match.group().replace(',', '')) == float('inf'):
return None
return round(float(match.group().replace(',', '')))
else:
return None
else:
return None
def batch_data(data_list, batch_size=1):
n = len(data_list) // batch_size
batch_data = []
for i in range(n-1):
start = i * batch_size
end = (i+1)*batch_size
batch_data.append(data_list[start:end])
last_start = (n-1) * batch_size
last_end = MAX_INT
batch_data.append(data_list[last_start:last_end])
return batch_data
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
INVALID_ANS = "[invalid]"
gsm8k_ins = []
gsm8k_answers = []
problem_prompt = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
)
print('prompt =====', problem_prompt)
with open(data_path,"r+", encoding="utf8") as f:
for idx, item in enumerate(jsonlines.Reader(f)):
temp_instr = problem_prompt.format(instruction=item["query"])
gsm8k_ins.append(temp_instr)
temp_ans = item['response'].split('#### ')[1]
temp_ans = int(temp_ans.replace(',', ''))
gsm8k_answers.append(temp_ans)
gsm8k_ins = gsm8k_ins[start:end]
gsm8k_answers = gsm8k_answers[start:end]
print('length ====', len(gsm8k_ins))
batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)
stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512, stop=stop_tokens)
print('sampling =====', sampling_params)
llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
result = []
res_completions = []
for idx, (prompt, prompt_answer) in enumerate(zip(batch_gsm8k_ins, gsm8k_answers)):
if isinstance(prompt, list):
pass
else:
prompt = [prompt]
completions = llm.generate(prompt, sampling_params)
for output in completions:
prompt = output.prompt
generated_text = output.outputs[0].text
res_completions.append(generated_text)
invalid_outputs = []
for idx, (prompt, completion, prompt_answer) in enumerate(zip(gsm8k_ins, res_completions, gsm8k_answers)):
doc = {'question': prompt}
y_pred = extract_answer_number(completion)
if y_pred is not None:
result.append(float(y_pred) == float(prompt_answer))
else:
result.append(False)
temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
invalid_outputs.append(temp)
acc = sum(result) / len(result)
# print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', invalid_outputs)
print('start===', start, ', end====', end)
print('gsm8k length====', len(result), ', gsm8k acc====', acc)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str) # model path
parser.add_argument("--data_file", type=str, default='') # data path
parser.add_argument("--start", type=int, default=0) #start index
parser.add_argument("--end", type=int, default=MAX_INT) # end index
parser.add_argument("--batch_size", type=int, default=400) # batch_size
parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)

View File

@ -0,0 +1,115 @@
import argparse
import json
import pdb
import jsonlines
import util
from vllm import LLM, SamplingParams
import sys
MAX_INT = sys.maxsize
INVALID_ANS = "[invalid]"
invalid_outputs = []
def remove_boxed(s):
left = "\\boxed{"
try:
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
except:
return None
def process_results(doc, completion, answer):
split_ans = completion.split('The answer is: ')
if len(split_ans) > 1:
ans = split_ans[-1]
extract_ans_temp = ans.split('.\n')[0]
extract_ans_temp = extract_ans_temp.strip()
if len(extract_ans_temp)>0 and extract_ans_temp[-1] == '.':
extract_ans = extract_ans_temp[0:-1]
else:
extract_ans = extract_ans_temp
extract_ans = extract_ans.strip()
if util.is_equiv(extract_ans, answer):
return True
else:
return False
else:
temp = {'question': doc, 'output': completion, 'answer': answer}
invalid_outputs.append(temp)
return False
def batch_data(data_list, batch_size=1):
n = len(data_list) // batch_size
batch_data = []
for i in range(n-1):
start = i * batch_size
end = (i+1)*batch_size
batch_data.append(data_list[start:end])
last_start = (n-1) * batch_size
last_end = MAX_INT
batch_data.append(data_list[last_start:last_end])
return batch_data
def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
hendrycks_math_ins = []
hendrycks_math_answers = []
problem_prompt = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
)
print('prompt =====', problem_prompt)
with open(data_path, "r+", encoding="utf8") as f:
for idx, item in enumerate(jsonlines.Reader(f)):
temp_instr = problem_prompt.format(instruction=item["instruction"])
hendrycks_math_ins.append(temp_instr)
solution = item['output']
temp_ans = remove_boxed(util.last_boxed_only_string(solution))
hendrycks_math_answers.append(temp_ans)
print('total length ===', len(hendrycks_math_ins))
hendrycks_math_ins = hendrycks_math_ins[start:end]
hendrycks_math_answers = hendrycks_math_answers[start:end]
print('length ====', len(hendrycks_math_ins))
batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)
stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
print('sampling =====', sampling_params)
llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
res_completions = []
for idx, (prompt, prompt_answer) in enumerate(zip(batch_hendrycks_math_ins, hendrycks_math_answers)):
if isinstance(prompt, list):
pass
else:
prompt = [prompt]
completions = llm.generate(prompt, sampling_params)
for output in completions:
prompt_temp = output.prompt
generated_text = output.outputs[0].text
res_completions.append(generated_text)
results = []
for idx, (prompt, completion, prompt_answer) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)):
res = process_results(prompt, completion, prompt_answer)
results.append(res)
acc = sum(results) / len(results)
# print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', invalid_outputs)
print('start===', start, ', end====',end)
print('length====', len(results), ', acc====', acc)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default='') # model path
parser.add_argument("--data_file", type=str, default='') # data path
parser.add_argument("--start", type=int, default=0) #start index
parser.add_argument("--end", type=int, default=MAX_INT) # end index
parser.add_argument("--batch_size", type=int, default=400) # batch_size
parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
test_hendrycks_math(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 317 KiB

View File

@ -0,0 +1,16 @@
transformers>=4.34.0
wandb==0.15.3
torch==2.0.1
sentencepiece==0.1.99
tokenizers==0.13.3
accelerate==0.21.0
bitsandbytes==0.40.0
vllm
fraction
tqdm
numpy
fire
openai
scipy
jsonlines
pandas

View File

@ -0,0 +1,39 @@
export MODEL_PATH='path_to_Llama-2-7b-hf'
export SAVE_PATH='metamath_llama_mft'
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export WANDB_DISABLED=true
export HF_TOKEN="your huggingface access token"
export mask_p=0.2
export warmup_steps=75000
export lr_decay_ratio=0.6
export min_lr_ratio=0.05
export decay=True
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env train_math.py \
--model_name_or_path $MODEL_PATH \
--data_path "MetaMathQA/MetaMathQA-395K.json" \
--data_length 10000000 \
--bf16 True \
--output_dir $SAVE_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--save_total_limit 20 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--model_max_length 1024 \
--tf32 True
python eval_gsm8k.py --model $SAVE_PATH --data_file ./data/test/GSM8K_test.jsonl
python eval_math.py --model $SAVE_PATH --data_file ./data/test/MATH_test.jsonl

View File

@ -0,0 +1,39 @@
export MODEL_PATH="mistralai/Mistral-7B-v0.1"
export SAVE_PATH='metamath_mistral_mft'
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export WANDB_DISABLED=true
export HF_TOKEN="your huggingface access token"
export mask_p=0.2
export warmup_steps=75000
export lr_decay_ratio=0.6
export min_lr_ratio=0.05
export decay=True
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env train_math.py \
--model_name_or_path $MODEL_PATH \
--data_path "MetaMathQA/MetaMathQA-395K.json" \
--data_length 10000000 \
--bf16 True \
--output_dir $SAVE_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--save_total_limit 20 \
--learning_rate 3e-6 \
--weight_decay 0. \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
--model_max_length 1024 \
--tf32 True
python eval_gsm8k.py --model $SAVE_PATH --data_file ./data/test/GSM8K_test.jsonl
python eval_math.py --model $SAVE_PATH --data_file ./data/test/MATH_test.jsonl

View File

@ -0,0 +1,440 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan
import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random
import numpy as np
import math
random.seed(42)
def setup_seed(seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark=False
torch.backends.cudnn.deterministic=True
seed = os.environ.get('seed')
if seed is not None:
seed = int(seed)
setup_seed(seed)
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict
# Masked-thought and LR-schedule hyperparameters are passed via environment variables (see run.sh / run_mistral.sh).
lr_decay_ratio = os.environ.get('lr_decay_ratio')
min_lr_ratio = os.environ.get('min_lr_ratio')
decay = os.environ.get('decay')
mask_probability = os.environ.get('mask_p')
random_mask_p = os.environ.get('random_mask_p')
warmup_steps = os.environ.get('warmup_steps')
drop_mask = os.environ.get('drop_mask')
if decay is not None:
decay = True
else:
decay = False
if drop_mask is None:
drop_mask = False
else:
drop_mask = True
if random_mask_p is not None:
random_mask_p = True
else:
random_mask_p = False
if warmup_steps is None:
warmup_steps = 0
else:
warmup_steps = int(warmup_steps)
if mask_probability is None:
mask_probability = 0
mask_probability = float(mask_probability)
if lr_decay_ratio is None:
lr_decay_ratio = 1
else:
lr_decay_ratio = float(lr_decay_ratio)
if min_lr_ratio is None:
min_lr_ratio = 0
else:
min_lr_ratio = float(min_lr_ratio)
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
#### 28
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
overwrite_output_dir: bool = field(default=True)
flash_attn: bool = field(default=False)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
# Randomly replace target (response) tokens with the <mask> token; positions
# before source_len belong to the prompt and are never masked.
def mask_target_tokens(input_ids, sources_tokenized, mask_probability, MASK_INDEX):
masked_input_ids = copy.deepcopy(input_ids)
for input_id, source_len in zip(masked_input_ids, sources_tokenized["input_ids_lens"]):
for i in range(source_len, len(input_id)):
if random.random() < mask_probability:
input_id[i] = MASK_INDEX
return masked_input_ids
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
steps
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
# Masked thought fine-tuning: mask a fraction of the response tokens in the inputs,
# optionally ramping the mask probability linearly over the first `warmup_steps` collator calls.
if mask_probability != 0:
if warmup_steps > 0 and steps < warmup_steps:
_mask_probability = steps / warmup_steps * mask_probability
else:
_mask_probability = mask_probability
if random_mask_p:
_mask_probability = random.random() * mask_probability
# assert 1==0
MASK_INDEX = tokenizer.convert_tokens_to_ids(['<mask>'])[0]
# print(MASK_INDEX, tokenizer.pad_token_id)
# assert 1==0
masked_input_ids = mask_target_tokens(input_ids, sources_tokenized, _mask_probability, MASK_INDEX)
input_ids = masked_input_ids
index = 0
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
data_path = data_args.data_path
try:
data_path = data_path_map[data_path]
except:
data_path = data_path
try:
list_data_dict = jload(data_path)
except BaseException:
with open(data_path, 'r') as f:
lines = f.readlines()
list_data_dict = [json.loads(line.strip()) for line in lines]
list_data_dict = random.sample(list_data_dict, len(list_data_dict))
list_data_dict = list_data_dict[:data_args.data_length]
# logging.warning("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
# print(list_data_dict[0])
if 'instruction' in list_data_dict[0]:
pass
else:
def get_input(query):
if query.find('\n') == -1:
return ''
return '\n'.join(query.split('\n')[1:])
list_data_dict = [{'instruction':data['query'].split('\n')[0], 'input':get_input(data['query']), 'output':data['response']} for data in list_data_dict]
# import ipdb; ipdb.set_trace()
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
self.sources = sources
self.targets = targets
def __len__(self):
return len(self.sources)
def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __getitem__(self, i):
return dict(input_ids=self.sources[i], labels=self.targets[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
steps: int = 0
def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
sources = []
targets = []
for instance in instances:
source = instance['input_ids']
target = instance['labels']
sources.append(source)
targets.append(target)
self.steps += 1
# print(self.steps, len(instances))
data_dict = preprocess(sources, targets, self.tokenizer, self.steps)
input_ids, labels = data_dict['input_ids'], data_dict['labels']
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
if drop_mask:
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
MASK_INDEX = self.tokenizer.convert_tokens_to_ids(['<mask>'])[0]
attention_mask = attention_mask & input_ids.ne(MASK_INDEX)
else:
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
data_args.data_length = int(remaining_args[1])
if training_args.flash_attn:
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
use_cache=False,
).to('cuda')
else:
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
if mask_probability != 0:
tokenizer.add_tokens(["<mask>"])
def resize(model):
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
old_num_tokens = input_embeddings.shape[0]
num_new_tokens = len(tokenizer) - old_num_tokens
print('num_new_tokens', num_new_tokens)
if num_new_tokens != 0:
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
model.resize_token_embeddings(len(tokenizer))
model.get_input_embeddings().weight.data[-num_new_tokens:] = input_embeddings_avg
model.get_output_embeddings().weight.data[-num_new_tokens:] = output_embeddings_avg
resize(model)
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
if "llama" in model_args.model_name_or_path:
tokenizer.add_special_tokens(
{
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
}
)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
# Cosine schedule with warmup, modified so the learning rate stops decaying after
# `lr_decay_ratio` of training and is floored at `min_lr_ratio`.
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
lr_decay_steps = num_training_steps * lr_decay_ratio
if current_step > lr_decay_steps:
if decay:
remaining_steps = num_training_steps - current_step
# return min_lr_ratio * (remaining_steps / max(1, remaining_steps))
return max(0, min_lr_ratio * (remaining_steps / max(1, num_training_steps * (1 - lr_decay_ratio))))
return min_lr_ratio
progress = float(current_step - num_warmup_steps) / float(max(1, lr_decay_steps - num_warmup_steps))
coefficient = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return min_lr_ratio + coefficient * (1.0 - min_lr_ratio)
def add_lr_decay_limit_for_cosine_schedule():
transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda = _get_cosine_schedule_with_warmup_lr_lambda
add_lr_decay_limit_for_cosine_schedule()
trainer.train()
trainer.save_state()
# if os.environ.get('LOCAL_RANK') == '0':
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
if __name__ == "__main__":
train()

View File

@ -0,0 +1,248 @@
import pprint
def last_boxed_only(sample):
q, a = sample
a = last_boxed_only_string(a)
if a is None:
return None
return (q, a)
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
def only_until_first_boxed_from_tokens(string, tokens):
idx = string.find("\\boxed")
if idx < 0:
idx = string.find("\\fbox")
if idx < 0:
return None
cum_length = 0
for i, t in enumerate(tokens):
cum_length += len(t)
if cum_length >= idx:
break
return tokens[:i]
def clean_numbers(sample):
if not sample:
return None
new_sample = list()
for s in sample:
new_sample.append(_clean_numbers(s))
return tuple(new_sample)
def _clean_numbers(string):
"""
Clean Numbers in the given string
>>> _clean_numbers(None, "Hello 123")
'Hello 123'
>>> _clean_numbers(None, "Hello 1234")
'Hello 1,234'
>>> _clean_numbers(None, "Hello 1234324asdasd")
'Hello 1,234,324asdasd'
"""
num_prev_digits = 0
new_string = ""
for i, c in enumerate(string):
# isdigit() doesn't work here because of weird unicode chars.
if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
num_prev_digits += 1
else:
if num_prev_digits > 3:
# Some fixing
string_number = new_string[-num_prev_digits:]
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
num_prev_digits = 0
new_string += c
if num_prev_digits > 3:
# Some fixing
string_number = new_string[-num_prev_digits:]
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
return new_string
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except AssertionError:
return string
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
#pdb.set_trace()
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
class NotEqual:
def __eq__(self, other):
return False

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Neel Jain
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,50 @@
# NEFTune
[10/17/2023] NEFTune has been integrated into Hugging Face's TRL (Transformer Reinforcement Learning) library and the HF Trainer. [See Announcement](https://x.com/younesbelkada/status/1714283468790935687?s=20).
[11/25/2023] NEFTune has been integrated into [Ludwig.ai](https://ludwig.ai) for LLM fine-tuning. [See PR](https://github.com/ludwig-ai/ludwig/pull/3744).
Please see the [limitations](https://github.com/neelsjain/NEFTune#limitations) of our study below. Additionally, for generation, we suggest using greedy decoding with a repetition penalty of $1.2$. Note that without the repetition penalty, we have seen performance degrade and generations degenerate.
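As a concrete illustration (not the exact script used in the paper), greedy decoding with a repetition penalty of 1.2 can be set up with Hugging Face `transformers` roughly as follows; the checkpoint path and the prompt are placeholders:
```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "path/to/neftune-finetuned-model"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Give me three tips for staying productive.", return_tensors="pt")
# Greedy decoding (do_sample=False) with the suggested repetition penalty of 1.2.
outputs = model.generate(**inputs, do_sample=False, repetition_penalty=1.2, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```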
Feel free to reach out to me as well (email at the bottom of the page 1 in the paper).
## Overview
In this paper, we propose to add random noise to the embedding vectors of the training data during the forward pass of fine-tuning. We show that this simple trick can improve the outcome of instruction fine-tuning, often by a large margin, with no additional compute or data overhead. <u>N</u>oisy <u>E</u>mbedding Instruction <u>F</u>ine <u>T</u>uning (NEFTune), while simple, has a strong impact on downstream conversational quality. When a raw LLM like LLaMA-2-7B is finetuned with noisy embeddings on the popular Alpaca dataset, its performance on *AlpacaEval* improves from 29.8\% to 64.7\% -- an impressive boost of around 35 percentage points. NEFTune leads to this surprising and large jump in performance on conversational tasks while maintaining performance on factual question answering baselines. Using noisy embeddings seems to be a free lunch for LLM fine-tuning. The paper can be found [here](https://arxiv.org/abs/2310.05914).
<p align="center">
<img src=imgs/AlpacaEval_Figue1.png width="80%">
</p>
Note: During training, we observed the training loss values may be very close. The loss histograms in the Analysis section in the paper are the loss values when the noise is turned off. This is where the difference should show up.
## Code
The easiest way to incorporate NEFTune into your training procedure is to rewrite the forward pass of the embedding layer. An example of one way to do this for LLaMA is provided below. Note that different distributed training setups will require different implementations.
```
import torch
from torch.nn import functional as F
def NEFTune(model, noise_alpha=5):
def noised_embed(orig_embed, noise_alpha):
def new_func(x):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(x)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha/torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
else:
return orig_embed(x)
return new_func
##### NOTE: this is for a LLaMA model #####
##### For a different model, you need to change the attribute path to the embedding #####
model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)
return model
```
The code we used to run the experiments can be found in the `experiment_code` folder.
## Limitations
Our study has several limitations. We adopt AlpacaEval as our central measure of instruction following ability for LLMs, which is subject to the biases of a single judge (GPT-4). Additionally, due to limited computing resources, we were not able to validate the success of NEFTune on larger 70B variants of LLaMA-2 on multiple datasets, and we had to rely on fixed hyper-parameters for most NEFTune runs rather than sweeping. Finally, despite our empirical studies, we do not have a conclusive understanding of why NEFTune works.
## Call for Feedback
Our study was limited to the settings that we explored. Please feel free to open an issue regarding any weakness of NEFTune. We hope to be open about any issues with NEFTune to help future research and users.

View File

@ -0,0 +1,10 @@
logs/
data_instruct/
__pycache__/
wandb/
.vscode/
outputs/
results/
log/
dummy_script.sh
checkpoints/

View File

@ -0,0 +1,55 @@
# QuickStart
First, you want to install the environment (assuming that you have conda installed)
```
conda create -n hug python=3.11
conda activate hug
pip install -r requirements.txt
```
Please use the pytorch version that is specified in the requirements. Otherwise, this may cause some problems when loading in the model in `train.py`.
# Converting checkpoints from huggingface to fsdp
The `convert_hf_to_fsdp.py` script converts a huggingface checkpoint into one that can be loaded by FSDP. After conversion, the model can be loaded in a distributed manner that consumes much less memory. Usually, when loading a huggingface model onto N GPUs, one first needs to realize N copies of the model in CPU memory before moving them to the GPUs, which can easily blow out the CPU memory if the model is large. You can convert the model by running the command below. We ran these experiments on 4xA5000 GPUs; each A5000 has 24GB of GPU memory.
```
python convert_hf_to_fsdp.py --load_path $HF_CHECKPOINT_PATH --save_path $SAVE_PATH --add_tokens $NUM_TOKENS
# `$NUM_TOKENS` is the number of new tokens that one wants to add to the pretrained model. In the case of llama, we add an additional padding token since it doesn't have one originally. For opt, we don't need to add new tokens, since it already contains all special tokens.
```
If you want to use a model that is on the huggingface hub, you can run the command below. We will use Llama-2-7b-hf as an example
```
SAVE_PATH_SHARDED=pretrained_models/Llama2_7b_sharded
SAVE_PATH_HF=pretrained_models/Llama2_7b_hf
python convert_hf_to_fsdp.py --load_path meta-llama/Llama-2-7b-hf \
--save_path $SAVE_PATH_SHARDED \
--save_path_hf $SAVE_PATH_HF
```
The script above will save a sharded version of the model in `$SAVE_PATH_SHARDED` and a huggingface checkpoint at `$SAVE_PATH_HF`. The sharded file only contains the weight, and the huggingface checkpoint is still needed to initialize the architecture and tokenizer.
# Training
Vanilla Fine-tuning
```
python train.py --init_checkpoint_path <fsdp_model_path> \
--model_config_path <hf_model_path> --wrapped_class_name LlamaDecoderLayer \
--data_path datasets/alpaca-train.jsonl --added_tokens 1 \
--act_checkpointing --lr 5e-5 --accumulation_steps 8 --batch_size 4 \
--checkpoint_path ./checkpoints/naive --hack --wandb --wb_name naive
```
NEFTune with a noise magnitude of 5
```
python train.py --init_checkpoint_path <fsdp_model_path> \
--model_config_path <hf_model_path> --wrapped_class_name LlamaDecoderLayer \
--data_path datasets/alpaca-train.jsonl --added_tokens 1 \
--act_checkpointing --lr 5e-5 --accumulation_steps 8 --batch_size 4 \
--checkpoint_path ./checkpoints/neftune --hack --wandb --wb_name neftune \
--neftune_alpha 5
```
# Evaluation
You may use the script here: `scripts/alpaca_eval.sh`
# Acknowledgement (Code)
A big thank you to [Ping](https://github.com/Ping-C) for developing the foundations of this code. Thanks also to the Alpaca and FastChat projects, which were vital in the development of this code.

View File

@ -0,0 +1,500 @@
"""
Conversation prompt templates.
"""
import dataclasses
from enum import auto, Enum
from typing import List, Any, Dict
class SeparatorStyle(Enum):
"""Separator styles."""
ADD_COLON_SINGLE = auto()
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
ADD_NEW_LINE_SINGLE = auto()
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
# The name of this template
name: str
# The system prompt
system: str
# Two roles
roles: List[str]
# All messages. Each item is (role, message).
messages: List[List[str]]
# The number of few shot examples
offset: int
# Separators
sep_style: SeparatorStyle
sep: str
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: str = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
else:
ret += role + ": "  # must end with a space
return ret
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
ret = self.system
for role, message in self.messages:
if message:
ret += role + message + self.sep
else:
ret += role
return ret
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep
else:
ret += role + "\n"
return ret
elif self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
if i % 2 == 1:
ret += "\n\n"
else:
ret += role + ":\n"
return ret
elif self.sep_style == SeparatorStyle.RWKV:
ret = self.system
for i, (role, message) in enumerate(self.messages):
if message:
ret += (
role
+ ": "
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
)
ret += "\n\n"
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.PHOENIX:
ret = self.system
for role, message in self.messages:
if message:
ret += role + ": " + "<s>" + message + "</s>"
else:
ret += role + ": " + "<s>"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
def update_last_message(self, message: str):
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
"""
self.messages[-1][1] = message
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format"""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
ret = [{"role": "system", "content": self.system}]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
if msg is not None:
ret.append({"role": "assistant", "content": msg})
return ret
def copy(self):
return Conversation(
name=self.name,
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
)
def dict(self):
return {
"name": self.name,
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
}
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}
def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."
conv_templates[template.name] = template
def get_conv_template(name: str) -> Conversation:
"""Get a conversation template."""
return conv_templates[name].copy()
# A template with one conversation example
register_conv_template(
Conversation(
name="one_shot",
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
(
"Human",
"Got any creative ideas for a 10 year olds birthday?",
),
(
"Assistant",
"""Of course! Here are some creative ideas for a 10-year-old's birthday party:
1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
),
),
offset=2,
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="\n### ",
stop_str="###",
)
)
# Vicuna v1.1 template
register_conv_template(
Conversation(
name="vicuna_v1.1",
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep=" ",
sep2="</s>",
)
)
# Koala default template
register_conv_template(
Conversation(
name="koala_v1",
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep=" ",
sep2="</s>",
)
)
# Alpaca default template
register_conv_template(
Conversation(
name="alpaca",
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
roles=("### Instruction:", "### Response:"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
sep="\n\n",
)
)
# Dolly V2 default template
register_conv_template(
Conversation(
name="dolly_v2",
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
roles=("### Instruction", "### Response"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DOLLY,
sep="\n\n",
sep2="### End",
)
)
# OpenAssistant Pythia default template
register_conv_template(
Conversation(
name="oasst_pythia",
system="",
roles=("<|prompter|>", "<|assistant|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="<|endoftext|>",
)
)
# StableLM Alpha default template
register_conv_template(
Conversation(
name="stablelm",
system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
roles=("<|USER|>", "<|ASSISTANT|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="",
stop_token_ids=[50278, 50279, 50277, 1, 0],
)
)
# Baize default template
register_conv_template(
Conversation(
name="baize",
system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
roles=("[|Human|]", "[|AI|]"),
messages=(
("[|Human|]", "Hello!"),
("[|AI|]", "Hi!"),
),
offset=2,
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="\n",
stop_str="[|Human|]",
)
)
# RWKV-4-Raven default template
register_conv_template(
Conversation(
name="rwkv",
system="",
roles=("Bob", "Alice"),
messages=(
("Bob", "hi"),
(
"Alice",
"Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
),
),
offset=2,
sep_style=SeparatorStyle.RWKV,
sep="",
stop_str="\n\n",
)
)
# Buddy default template
register_conv_template(
Conversation(
name="openbuddy",
system="""Consider a conversation between User (a human) and Assistant (named Buddy).
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
Buddy cannot access the Internet.
Buddy can fluently speak the user's language (e.g. English, Chinese).
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
Buddy possesses vast knowledge about the world, history, and culture.
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
roles=("User", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="\n",
)
)
# Phoenix default template
register_conv_template(
Conversation(
name="phoenix",
system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.PHOENIX,
sep="</s>",
)
)
# ChatGPT default template
register_conv_template(
Conversation(
name="chatgpt",
system="You are a helpful assistant.",
roles=("user", "assistant"),
messages=(),
offset=0,
sep_style=None,
sep=None,
)
)
# Claude default template
register_conv_template(
Conversation(
name="claude",
system="",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="\n\n",
)
)
# MPT default template
register_conv_template(
Conversation(
name="mpt",
system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
""",
roles=("<|im_start|>user", "<|im_start|>assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
sep="<|im_end|>",
stop_token_ids=[50278, 0],
)
)
# Bard default template
# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
register_conv_template(
Conversation(
name="bard",
system="",
roles=("0", "1"),
messages=(),
offset=0,
sep_style=None,
sep=None,
)
)
# BiLLa default template
register_conv_template(
Conversation(
name="billa",
system="",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
sep="\n",
stop_str="Human:",
)
)
# RedPajama INCITE default template
register_conv_template(
Conversation(
name="redpajama-incite",
system="",
roles=("<human>", "<bot>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="\n",
stop_str="<human>",
)
)
# h2oGPT default template
register_conv_template(
Conversation(
name="h2ogpt",
system="",
roles=("<|prompt|>", "<|answer|>"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="</s>",
)
)
if __name__ == "__main__":
conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi!")
conv.append_message(conv.roles[0], "How are you?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())

View File

@ -0,0 +1,29 @@
3"""Convert hf model to checkpoint consummable by fsdp"""
import argparse
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import torch.distributed._shard.checkpoint as dist_cp
import torch
import transformers
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--load_path", type=str, default="llama/7B_sharded")
parser.add_argument("--save_path", type=str, default="llama/7B_new_hf")
parser.add_argument("--added_tokens", type=int, default=1, help="Number of tokens added to the model that need to be ")
parser.add_argument("--config_path", type=str, default="llama/7B_hf")
args = parser.parse_args()
model_config = transformers.AutoConfig.from_pretrained(args.config_path)
model = LlamaForCausalLM(model_config).bfloat16()
if args.added_tokens > 0:
model.resize_token_embeddings(model.config.vocab_size + args.added_tokens)
state_dict = model.state_dict()
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(args.load_path),
no_dist=True
)
model.load_state_dict(state_dict)
model.save_pretrained(args.save_path)

View File

@ -0,0 +1,41 @@
"""Convert hf model to checkpoint consummable by fsdp"""
import argparse
import transformers
import torch.distributed._shard.checkpoint as dist_cp
from utils import make_nonpersistent_buffer_persistent
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--load_path", type=str, default="llama/7B_hf")
parser.add_argument("--save_path", type=str, default="llama/7B_sharded")
parser.add_argument("--save_path_hf", type=str, default=None, help="This is the path to save the model in HF format, is optional")
parser.add_argument("--add_tokens", type=int, default=0, help="Number of additional tokens to add to the model")
parser.add_argument("--cache_dir", type=str, default=None, help="This can be used to store the HF model in a different location than the default if using hf path as opposed to local directory")
args = parser.parse_args()
model = transformers.AutoModelForCausalLM.from_pretrained(args.load_path, cache_dir=args.cache_dir)
model = model.to(model.config.torch_dtype) # from_pretrained does not load model weights to the default type, so we have to do it manually
if args.add_tokens > 0:
model.resize_token_embeddings(model.config.vocab_size + args.add_tokens)
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-args.add_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-args.add_tokens].mean(dim=0, keepdim=True)
input_embeddings[-args.add_tokens:] = input_embeddings_avg
output_embeddings[-args.add_tokens:] = output_embeddings_avg
print('added tokens')
# save the huggingface config/weights/tokenizers if a path is specified,
    # this is only necessary when load_path is pointing to a huggingface model
if args.save_path_hf is not None:
model.save_pretrained(args.save_path_hf)
transformers.AutoTokenizer.from_pretrained(args.load_path, cache_dir=args.cache_dir).save_pretrained(args.save_path_hf)
# by making nonpersistent buffer persistent, state_dict now includes the original nonpersistent buffer values,
# which can be used to override the incorrect nonpersistent buffer when initializing the model directly on gpu
make_nonpersistent_buffer_persistent(model)
dist_cp.save_state_dict(
state_dict=model.state_dict(),
storage_writer=dist_cp.FileSystemWriter(args.save_path),
no_dist=True
)

View File

@ -0,0 +1,242 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import copy
import logging
import random
from dataclasses import dataclass
from typing import Optional, Dict, Sequence
import os
import torch
import transformers
from torch.utils.data import Dataset
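# Masking behaviour (MFT) is configured through environment variables set by the
# training scripts: `mask_p` gives the probability of masking a target token,
# `warmup_steps` linearly warms that probability up from 0, and `use_random`
# replaces masked positions with random vocabulary tokens instead of MASK_INDEX.
# (`dec` and `random_mask_p` are parsed below but not used in this file.)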
mask_probability = os.environ.get('mask_p')
dec = os.environ.get('dec')
random_mask_p = os.environ.get('random_mask_p')
warmup_steps = os.environ.get('warmup_steps')
use_random = os.environ.get('use_random')
if use_random is not None:
use_random = True
if random_mask_p is not None:
random_mask_p = True
else:
random_mask_p = False
if warmup_steps is None:
warmup_steps = 0
else:
warmup_steps = int(warmup_steps)
if mask_probability is None:
mask_probability = 0
else:
mask_probability = float(mask_probability)
if dec is not None:
dec = True
else:
dec = False
IGNORE_INDEX = -100
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def mask_target_tokens(input_ids, sources_tokenized, mask_probability, MASK_INDEX, tokenizer):
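    """Randomly replace target-side tokens (positions after the prompt) with
    MASK_INDEX with probability `mask_probability`; when `use_random` is set,
    a uniformly sampled vocabulary token is substituted instead."""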
masked_input_ids = copy.deepcopy(input_ids)
vocab_size = int(tokenizer.vocab_size)
for input_id, source_len in zip(masked_input_ids, sources_tokenized["input_ids_lens"]):
for i in range(source_len, len(input_id)):
if random.random() < mask_probability:
                if use_random:
                    # replace with a random vocabulary token instead of the mask token
                    input_id[i] = random.randint(0, vocab_size - 1)
                else:
                    input_id[i] = MASK_INDEX
return masked_input_ids
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
steps
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
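    # Optionally apply MFT masking to the inputs; the mask probability is linearly
    # warmed up from 0 to `mask_probability` over the first `warmup_steps` collator calls.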
if mask_probability != 0:
if warmup_steps > 0 and steps < warmup_steps:
_mask_probability = steps / warmup_steps * mask_probability
else:
_mask_probability = mask_probability
print(_mask_probability)
# MASK_INDEX = tokenizer.convert_tokens_to_ids(['<mask>'])[0]
MASK_INDEX = 0
# print(MASK_INDEX, tokenizer.pad_token_id)
# assert 1==0
masked_input_ids = mask_target_tokens(input_ids, sources_tokenized, _mask_probability, MASK_INDEX, tokenizer)
input_ids = masked_input_ids
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
from io_utils import read_jsonlines
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_fraction: float=1.0, seed: int=42):
super().__init__()
logging.warning("Loading data...")
if "dolly" in data_path:
list_data_dict = read_jsonlines(data_path)
list_data_dict = list(list_data_dict)
list_data_dict = [{"instruction": data_dict["instruction"],
"input": data_dict["context"],
"output": data_dict["response"]} for data_dict in list_data_dict]
else:
list_data_dict = jload(data_path)
used_data_count = int(len(list_data_dict)*data_fraction)
print(f"using {used_data_count} data out of {len(list_data_dict)}")
random.seed(seed)
random.shuffle(list_data_dict)
list_data_dict = list_data_dict[:used_data_count]
logging.warning("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
self.sources = sources
self.targets = targets
# logging.warning("Tokenizing inputs... This may take some time...")
# data_dict = preprocess(sources, targets, tokenizer)
# self.input_ids = data_dict["input_ids"]
# self.labels = data_dict["labels"]
    def __len__(self):
        return len(self.sources)
# def __getitem__(self, i) -> Dict[str, torch.Tensor]:
# return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __getitem__(self, i):
return dict(input_ids=self.sources[i], labels=self.targets[i])
@dataclass
class DataCollatorForSupervisedDataset:
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
steps: int = 0
    # Note: the dataset returns raw prompt/response strings, so tokenization (and
    # MFT masking) happens on the fly in __call__ below and can depend on the
    # current collator step.
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
sources = []
targets = []
for instance in instances:
source = instance['input_ids']
target = instance['labels']
sources.append(source)
targets.append(target)
self.steps += 1
data_dict = preprocess(sources, targets, self.tokenizer, self.steps)
input_ids, labels = data_dict['input_ids'], data_dict['labels']
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_path, data_fraction: float=1.0, seed: int=42) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path, data_fraction=data_fraction, seed=seed)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

View File

@ -0,0 +1,126 @@
"""This script generates the evaluation responses that can e used by eval_scoring.py"""
import argparse
from functools import partial
import json
from datasets import Dataset
import datasets
import transformers
from transformers.models.llama.configuration_llama import LlamaConfig
import torch
from io_utils import load_jsonlines
from utils import load_fsdp_ckpt_with_accelerate, add_padding_token
from conversation import get_conv_template
def apply_conv_template(example, template_type):
# preprocess instructions into prompted inputs
conv = get_conv_template(template_type)
conv.append_message(conv.roles[0], example['instruction'])
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()
example.update({
"prompt": prompt
})
return example
def generate_responses_batched(example, model, tokenizer, kwargs):
prompt = example['prompt']
print(prompt)
encoding = tokenizer(prompt,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
encoding = encoding.to(model.device)
with torch.no_grad():
model_output = model.generate(**encoding, **kwargs)
input_len = encoding.input_ids.shape[-1]
model_output = model_output[:, input_len:].cpu()
decoded_output = tokenizer.batch_decode(model_output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
del example['prompt']
example.update({"output": decoded_output})
example.update({"metadata": [kwargs] * len(decoded_output)})
return example
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="llama/7B_sharded", type=str)
parser.add_argument("--model_name", default=None, type=str)
parser.add_argument("--model_config_path", default="llama/7B_hf", type=str)
parser.add_argument("--template_type", default="alpaca", type=str)
parser.add_argument("--file_path", default="datasets/self-instruct-val(processed).jsonl", type=str)
parser.add_argument("--save_file_name", default="outputs/answers/self-instruct_llama7B.jsonl", type=str)
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument("--debug", action='store_true', help="This reduce the number of generation examples to 4, so that we can debug faster.")
args = parser.parse_args()
model_config = transformers.AutoConfig.from_pretrained(args.model_config_path)
if isinstance(model_config, LlamaConfig):
model_config.vocab_size += 1 # hardcode the vocab size for llama...
model = load_fsdp_ckpt_with_accelerate(args.model, model_config, hf_dummy_path=args.model_config_path, wrapped_class="LlamaDecoderLayer" if 'llama' in args.model else "OPTDecoderLayer")
tokenizer = transformers.AutoTokenizer.from_pretrained(
args.model_config_path,
model_max_length=2048,
padding_side="left",
use_fast=False,
)
add_padding_token(tokenizer)
## set the models to eval mode
model = model.eval()
if "dolly" in args.file_path or "vicuna" in args.file_path or "user_oriented_instructions" in args.file_path:
tasks = load_jsonlines(args.file_path)
raw_data = Dataset.from_list(tasks)
if "dolly" in args.file_path:
question_sources = "dolly"
elif "vicuna" in args.file_path:
question_sources = "vicuna"
raw_data = raw_data.rename_column("text", "instruction")
elif "user_oriented_instructions" in args.file_path:
question_sources = "alpaca"
elif "alpaca_eval" in args.file_path:
raw_data = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
# set generator field to model name
raw_data = raw_data.map(lambda x: {"generator": args.model_name if args.model_name else args.model})
# reduce number of examples for debugging
if args.debug:
raw_data = raw_data.select(range(4))
## preprocess
eval_preproc = partial(apply_conv_template, template_type=args.template_type)
raw_data = raw_data.map(eval_preproc)
## run generation
generate_kwargs = dict(max_length=2048, do_sample=False, top_p=0.1,
num_return_sequences=1, temperature=0.1,
repetition_penalty=1.2)
generate = partial(generate_responses_batched,
model=model,
tokenizer=tokenizer,
kwargs=generate_kwargs)
dataset_w_responses = raw_data.map(generate,
batched=True,
batch_size=args.batch_size)
if "alpaca_eval" in args.file_path:
# stringify the metadata, so that it becomes hashable
dataset_w_responses = dataset_w_responses.map(lambda x: {"metadata": json.dumps(x["metadata"])})
dataset_w_responses.to_json(args.save_file_name, orient="records", lines=False, indent=True)
else:
dataset_w_responses.to_json(args.save_file_name)

View File

@ -0,0 +1,43 @@
import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union, Any, Mapping, Iterable, List, Callable
import openai
import tqdm
from openai import openai_object
import copy
def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]:
"""Yields an iterable of Python dicts after reading jsonlines from the input file."""
file_size = os.path.getsize(filename)
with open(filename) as fp:
for line in tqdm.tqdm(fp.readlines(), desc=f"Reading JSON lines from {filename}", unit="lines"):
try:
example = json.loads(line)
yield example
except json.JSONDecodeError as ex:
logging.error(f'Input text: "{line}"')
logging.error(ex.args)
raise ex
def load_jsonlines(filename: str) -> List[Mapping[str, Any]]:
"""Returns a list of Python dicts after reading jsonlines from the input file."""
return list(read_jsonlines(filename))
def write_jsonlines(
objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x
):
"""Writes a list of Python Mappings as jsonlines at the input file."""
with open(filename, "w") as fp:
for obj in tqdm.tqdm(objs, desc=f"Writing JSON lines at {filename}"):
fp.write(json.dumps(to_dict(obj)))
fp.write("\n")

View File

@ -0,0 +1,28 @@
# this file removes the 'instances' field from user-oriented instructions and folds the input (or context) field into the instruction field
from io_utils import load_jsonlines
from datasets import Dataset
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--file_path", default="datasets/self-instruct-val", type=str)
parser.add_argument("--save_file_name", default="datasets/self-instruct-val(processed)", type=str)
args = parser.parse_args()
tasks = load_jsonlines(args.file_path)
tasks_processed = []
for task in tasks:
if 'instances' in task:
task_input = task['instances'][0]['input']
if task_input:
task['instruction'] += f"\n\n### Input:\n{task_input}"
del task['instances']
elif "context" in task:
context = task['context']
if context:
task['instruction'] += f"\n\n### Input:\n{task['context']}"
del task['context']
tasks_processed.append(task)
Dataset.from_list(tasks_processed).to_json(args.save_file_name)

View File

@ -0,0 +1,78 @@
"""This file shoud profile both the gpu usage and ram usage of loading a sharded model."""
import os
import json
import argparse
import torch
import torch.multiprocessing as mp
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from memory_profiler import memory_usage
from utils import (
get_gpu_memory_usage, fix_rotary_embedding_module_in_fsdp,
get_fsdp_wrapped_empty_model, load_state_dict_fsdp,
setup, cleanup
)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--sharded_model_path", type=str, default="llama/7B_sharded")
parser.add_argument("--config_path", type=str, default="llama/7B_hf")
parser.add_argument("--world_size", type=int, default=4)
parser.add_argument("--num_tokens", type=int, default=1024)
parser.add_argument("--offload_to_cpu", action="store_true", help="Offload state dictionary to cpu before loading to gpu")
parser.add_argument("--act_checkpoint", action="store_true", help="Use activation checkpointing to reduce the memory the model")
args = parser.parse_args()
return args
def profile_main(rank, world_size, args):
setup(rank, world_size)
torch.cuda.set_device(rank)
profile_results = {}
# 1. profile ram usage and gpu memory usage when loading the model
model_loaded = None
def load_model():
nonlocal model_loaded
model_config = AutoConfig.from_pretrained(args.config_path)
model_empty = get_fsdp_wrapped_empty_model(model_config, LlamaDecoderLayer)
model_empty_fixed = fix_rotary_embedding_module_in_fsdp(model_empty, model_config)
model_loaded = load_state_dict_fsdp(model_empty_fixed, args.sharded_model_path, offload_to_cpu=args.offload_to_cpu)
max_ram_usage_gb = round(memory_usage(load_model, interval=0.1, max_usage=True)/1024, 2)
profile_results['max_ram_usage_gb'] = max_ram_usage_gb
profile_results['max_gpu_memory_usage_after_loading_model'] = get_gpu_memory_usage(rank)['max']
# 1.5. apply activation checkpointing if specified, this will significantly reduce the memory usage of the model
if args.act_checkpoint:
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
apply_activation_checkpointing(
model_loaded, check_fn=check_fn
)
# 2. profile gpu memory usage when forwarding through the model without gradient
    mock_input_ids = torch.arange(0, args.num_tokens, dtype=torch.long)[None, :].cuda(rank)  # place the mock batch on this rank's GPU
with torch.no_grad():
out = model_loaded(mock_input_ids)
profile_results['max_gpu_memory_usage_after_nograd_forward'] = get_gpu_memory_usage(rank)['max']
# 3. profile gpu memory usage when forwarding through the model with gradient
out = model_loaded(mock_input_ids)
profile_results['max_gpu_memory_usage_after_withgrad_forward'] = get_gpu_memory_usage(rank)['max']
# 4. profile gpu memory usage when backwarding through the model
out.logits.sum().backward()
profile_results['max_gpu_memory_usage_after_backward'] = get_gpu_memory_usage(rank)['max']
print(f"rank: {rank} {json.dumps(profile_results, indent=4)}")
cleanup()
if __name__ == '__main__':
args = get_args()
mp.spawn(profile_main,
args=(args.world_size, args),
nprocs=args.world_size,
join=True)

View File

@ -0,0 +1,14 @@
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1+cu118
torchvision==0.15.2+cu118
torchaudio
datasets==2.12.0
transformers==4.30.1
sentencepiece==0.1.99
protobuf==3.20.0
accelerate==0.20.3
shortuuid
wandb
lion-pytorch
openai==0.27.8
git+https://github.com/tatsu-lab/alpaca_eval.git

View File

@ -0,0 +1,18 @@
# generating answers for alpaca_eval
MODEL_PATH=checkpoints/naive/1217/model
MODEL_NAME=llama7B_Alpaca
MODEL_CONFIG=models/llama7B/
ANSWER_FILE=outputs/answers/alpaca_eval_${MODEL_NAME}.json
python eval_generate.py --model $MODEL_PATH --file_path alpaca_eval --save_file_name $ANSWER_FILE --model_name $MODEL_NAME --model_config_path $MODEL_CONFIG
alpaca_eval --model_outputs $ANSWER_FILE --name $MODEL_NAME --annotators_config chatgpt_fn --caching_path=outputs/cached_annotations.json # gpt3.5 evaluation
alpaca_eval --model_outputs $ANSWER_FILE --name $MODEL_NAME --annotators_config alpaca_eval_gpt4 --caching_path=outputs/cached_annotations.json # gpt4 evaluation
# outputs/cached_annotations.json contains all prior annotations done, so that we don't have to reannotate the same examples.
# specifically, the annotator checks whether the pair of examples + annotator type exists in the cached_annotations.json file, and if so, it will not reannotate.
echo "=================================================="
echo "Below is the leaderboard with ChatGPT as annotator"
echo "=================================================="
alpaca_eval --annotators_config chatgpt_fn --leaderboard_mode_to_print=None
echo "=================================================="
echo "Below is the leaderboard with GPT4 as annotator"
echo "=================================================="
alpaca_eval --annotators_config alpaca_eval_gpt4 --leaderboard_mode_to_print=None

View File

@ -0,0 +1,237 @@
import os
import time
import argparse
import json
import random
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import transformers
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.trainer_utils import seed_worker
from transformers.optimization import get_cosine_schedule_with_warmup
from lion_pytorch import Lion
from dataset import make_supervised_data_module
from accelerate.data_loader import skip_first_batches
import wandb
from utils import (get_fsdp_wrapped_empty_model, load_model_opt_scheduler_states_fsdp,
load_state_dict_fsdp, save_model_opt_scheduler_states_fsdp,
add_padding_token
)
def setup(rank, world_size, port):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = port
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def get_empty_model(model_config_path, add_tokens=1, wrapped_class=None, hack=False):
model_config = transformers.AutoConfig.from_pretrained(model_config_path)
model_config.vocab_size += add_tokens
return get_fsdp_wrapped_empty_model(model_config, wrapped_class, hack=hack)
def get_model_opt_scheduler(added_tokens, model_config_path, max_steps=1000, warmup_ratio=0.03, weight_decay=0.0, lr=2e-5, wrapped_class=None, hack=False):
model = get_empty_model(model_config_path, add_tokens=added_tokens, wrapped_class=wrapped_class, hack=hack)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = get_cosine_schedule_with_warmup(opt, int(max_steps*warmup_ratio), num_training_steps=max_steps)
return model, opt, scheduler
def get_dataloader_and_sampler(train_dataset, data_collator, batch_size, rank, world_size=4):
sampler = DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank,
seed=0,
)
return DataLoader(
train_dataset,
batch_size=batch_size,
collate_fn=data_collator,
sampler=sampler,
drop_last=True,
num_workers=0,
pin_memory=True,
worker_init_fn=seed_worker,
), sampler
def get_class_from_class_name(class_name):
if class_name == "LlamaDecoderLayer":
return LlamaDecoderLayer
elif class_name == "OPTDecoderLayer":
return OPTDecoderLayer
else:
raise ValueError(f"Unknown class name {class_name}")
@record
def fsdp_main(rank, world_size, args):
setup(rank, world_size, args.port)
if rank == 0:
if args.wandb:
wandb.init(project=args.wb_project, name=args.wb_name, config=args, resume=args.resume)
torch.cuda.set_device(rank)
wrapped_class = get_class_from_class_name(args.wrapped_class_name)
model, opt, scheduler = get_model_opt_scheduler(
added_tokens=args.added_tokens,
model_config_path=args.model_config_path,
max_steps=args.max_steps, warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay, lr=args.lr,
wrapped_class=wrapped_class, hack=args.hack)
if args.resume:
model, opt, scheduler, start_step_count = load_model_opt_scheduler_states_fsdp(model, opt, scheduler, args.checkpoint_path)
else:
model = load_state_dict_fsdp(model, args.init_checkpoint_path)
start_step_count = 0
if args.act_checkpointing:
check_fn = lambda submodule: isinstance(submodule, wrapped_class)
apply_activation_checkpointing(
model, check_fn=check_fn
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
args.model_config_path,
model_max_length=512,
padding_side="right",
use_fast=False,
)
add_padding_token(tokenizer)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_path=args.data_path, data_fraction=args.data_fraction, seed=args.sample_seed)
train_dataset = data_module['train_dataset']
data_collator = data_module['data_collator']
dataloader_full, sampler = get_dataloader_and_sampler(train_dataset=train_dataset, data_collator=data_collator, batch_size=args.batch_size, rank=rank, world_size=world_size)
# updating the dataloader to the right state
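    # each optimizer step consumes `accumulation_steps` batches, so the saved step
    # count maps to an epoch (sub_step_count // len(dataloader_full)) and a number
    # of batches to skip within that epoch (sub_step_count % len(dataloader_full))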
step_count = start_step_count
sub_step_count = step_count * args.accumulation_steps
start_epoch = sub_step_count // len(dataloader_full)
skip_steps = sub_step_count % len(dataloader_full)
sampler.set_epoch(start_epoch)
dataloader = skip_first_batches(dataloader_full, skip_steps)
print("start_step_count", start_step_count, "step_count", step_count, "epoch", start_epoch, "skip_steps", skip_steps)
accumulation_steps = args.accumulation_steps
save_steps = args.save_steps
epoch_iterator = iter(dataloader)
start_time = time.time()
for step_count in range(start_step_count, args.max_steps):
train_loss = 0
for _ in range(accumulation_steps):
try:
data = next(epoch_iterator)
except StopIteration:
sampler.set_epoch(sampler.epoch + 1)
dataloader = dataloader_full
epoch_iterator = iter(dataloader)
data = next(epoch_iterator)
if args.neftune_alpha is not None:
if isinstance(model, torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel):
embed_device = model._fsdp_wrapped_module.model.embed_tokens.weight.device
embeds_init = model._fsdp_wrapped_module.model.embed_tokens.forward(data['input_ids'].to(embed_device))
### add noise to embeds
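                    # NEFTune-style noise: uniform noise in [-1, 1] is scaled by
                    # neftune_alpha / sqrt(seq_len * hidden_dim) per sequence and
                    # added to the input embeddings; padding positions are zeroed
                    # out via the attention mask.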
input_mask = data['attention_mask'].to(embeds_init) # B x L
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(embeds_init).uniform_(-1,1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * embeds_init.size(-1)
mag = args.neftune_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
data['inputs_embeds'] = delta + embeds_init
data['input_ids'] = None
### add noise to embeds
out = model(**data)
(out.loss/accumulation_steps).backward()
train_loss += out.loss.item()/accumulation_steps
model.clip_grad_norm_(args.max_grad_norm)
if rank == 0:
time_so_far = (time.time() - start_time)/ 3600
iteration_so_far = step_count - start_step_count
remaining_iterations = args.max_steps - step_count
estimated_time_per_iteration = time_so_far / (iteration_so_far+1)
remaining_time = estimated_time_per_iteration * remaining_iterations
previous_time = start_step_count * estimated_time_per_iteration
total_estimated_time = time_so_far + remaining_time + previous_time
metrics_dict = {"train/loss": train_loss, "train/learning_rate": scheduler.get_last_lr()[0], "train/global_step": step_count+1,
"train/time_so_far": time_so_far, "train/remaining_time": remaining_time,
"train/total_estimated_time": total_estimated_time,
"train/train_steps_per_second": 1/(estimated_time_per_iteration*3600),
"train/epoch": sampler.epoch}
if args.wandb:
wandb.log(metrics_dict, step=step_count)
print(json.dumps(metrics_dict, indent=4))
opt.step()
scheduler.step()
opt.zero_grad()
# save the model, optimizer, scheduler
if (step_count+1) % save_steps == 0 or (step_count+1) == args.max_steps:
if rank == 0:
print("saving checkpoint", step_count+1)
save_model_opt_scheduler_states_fsdp(model, opt, scheduler, step_count, args.checkpoint_path, rank, dont_save_opt=args.dont_save_opt,
keep_old_checkpoints=args.keep_old_checkpoints)
cleanup()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--init_checkpoint_path", type=str, default="llama/7B_sharded")
parser.add_argument("--model_config_path", type=str, default="llama/7B_hf")
parser.add_argument("--checkpoint_path", type=str, default="llama/7B_checkpoint")
parser.add_argument("--wrapped_class_name", type=str, choices=["LlamaDecoderLayer", "OPTDecoderLayer"], default="LlamaDecoderLayer",
help="the name of the class that is wrapped by the FSDP module")
parser.add_argument("--dont_save_opt",action='store_true', help="dont save optimizer and scheduler, this saves hard disk memory by trading off ability to resume the run")
parser.add_argument("--keep_old_checkpoints",action='store_true', help="keep the intermediate checkpoints during training")
parser.add_argument("--added_tokens", type=int, default=1)
parser.add_argument("--port", default=None)
parser.add_argument("--data_path", type=str, default="data_instruct/alpaca.json")
parser.add_argument("--data_fraction", type=float, default=1.0, help="fraction of data to use for training should be between 1 and 0")
parser.add_argument("--sample_seed", type=int, default=42, help="the random seed used for sampling a fraction of the data")
parser.add_argument("--resume", action='store_true')
parser.add_argument("--max_steps", type=int, default=52002*3//128)
parser.add_argument("--warmup_ratio", type=float, default=0.03)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--hack", action='store_true',
help="This is a hack to reduce memory usage of the model by first casting the model to bf16 before moving to gpu"
", it uses less memory. However, it does not necessarily have the same training behavior as non-hacked version")
parser.add_argument("--max_grad_norm", type=float, default=1)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--act_checkpointing", action='store_true')
parser.add_argument("--save_steps", type=int, default=(52002*3/128)//10)
parser.add_argument("--accumulation_steps", type=int, default=32)
parser.add_argument("--neftune_alpha", type=float, default=None)
# wandb associated arguments
parser.add_argument("--wandb", action='store_true')
parser.add_argument("--wb_project", type=str, default="data_instruct")
parser.add_argument("--wb_name", type=str, default="test")
args = parser.parse_args()
WORLD_SIZE = torch.cuda.device_count()
if args.port is None:
args.port = str(random.randint(1024, 65353)) # randomly generate ports if not specified
mp.spawn(fsdp_main,
args=(WORLD_SIZE, args),
nprocs=WORLD_SIZE,
join=True)

View File

@ -0,0 +1,210 @@
"""This file contains utils for the repository"""
import functools
import torch
import functools
import warnings
import logging
import os
import shutil
import json
import torch
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardedStateDictConfig, MixedPrecision
from torch.distributed.fsdp.api import StateDictType
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import transformers
import torch.distributed._shard.checkpoint as dist_cp
# a filter to suppress type storage warnings
warnings.filterwarnings("ignore", message="TypedStorage is deprecated.*")
def get_gpu_memory_usage(rank):
return {
'total': round(torch.cuda.get_device_properties(rank).total_memory / (1024**3), 2),
'max': round(torch.cuda.max_memory_allocated(rank) / (1024**3), 2),
'reserved': round(torch.cuda.memory_reserved(rank) / (1024**3), 2),
'allocated': round(torch.cuda.memory_allocated(rank) / (1024**3), 2)
}
def get_fsdp_wrapped_empty_model(model_config, wrapped_cls, hack=False):
with init_empty_weights():
if hack:
model = transformers.AutoModelForCausalLM.from_config(model_config).bfloat16()
else:
model = transformers.AutoModelForCausalLM.from_config(model_config)
        # this ensures that the non-persistent buffers are overridden by the saved values when loading the model
make_nonpersistent_buffer_persistent(model)
# hack to make the model wrappable by FSDP
model.reset_parameters = lambda: None
wrapped_cls.reset_parameters = lambda x: None
torch.nn.Embedding.reset_parameters = lambda x: None
# wrap the empty model inside of FSDP
my_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=set([wrapped_cls, torch.nn.Embedding]),
)
bf16 = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy, device_id=torch.cuda.current_device(), mixed_precision=bf16)
return model
def load_fsdp_ckpt_with_accelerate(fsdp_path, model_config, hf_dummy_path, wrapped_class):
    # hf_dummy_path is a checkpoint path in huggingface format that accelerate.load_checkpoint_and_dispatch uses.
    # In order to leverage that function, we have to provide this path so the weights are moved to the right devices
    # without blowing up memory. The weights will later be overwritten by the checkpoint from fsdp_path.
    # One requirement is that hf_dummy_path has to have the same shapes as the checkpoint in fsdp_path.
with init_empty_weights():
model_empty = transformers.AutoModelForCausalLM.from_config(model_config)
model_empty = model_empty.bfloat16()
model = load_checkpoint_and_dispatch(model_empty, hf_dummy_path, device_map="auto", no_split_module_classes=[wrapped_class, torch.nn.Embedding])
# this is a hack
# the model weights in hf_dummy_path may have a different vocab size than the desired fsdp model we are trying to
# load from fsdp_path, so we explicitly resize the token embeddings after the model has been loaded.
current_vocab_size_based_on_weight = model.get_output_embeddings().weight.shape[0]
if model.config.vocab_size != current_vocab_size_based_on_weight:
# this will retie the weights if it is supposed to be tied, but also causes the weight device to be weird
model.resize_token_embeddings(model.config.vocab_size)
print("device map used by accelerate:\n", json.dumps(model.hf_device_map, indent=4))
model = load_state_dict_fsdp(model, fsdp_path, offload_to_cpu=False, no_dist=True)
return model
def load_state_dict_fsdp(model, load_path, offload_to_cpu=True, no_dist=False):
if no_dist:
checkpoint = model.state_dict()
dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=dist_cp.FileSystemReader(load_path),
no_dist=no_dist
)
model.load_state_dict(checkpoint)
else:
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
checkpoint = model.state_dict()
dist_cp.load_state_dict(
state_dict=checkpoint,
storage_reader=dist_cp.FileSystemReader(load_path),
no_dist=no_dist
)
model.load_state_dict(checkpoint)
return model
def save_state_dict_fsdp(model, save_path, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
checkpoint = model.state_dict()
dist_cp.save_state_dict(
state_dict=checkpoint,
storage_writer=dist_cp.FileSystemWriter(save_path),
)
return model
def save_model_to_fsdp_format(model, save_path):
with LogLevelContext(logging.ERROR):
warnings.filterwarnings("ignore", message="TypedStorage is deprecated.*")
dist_cp.save_state_dict(
state_dict=model.state_dict(),
storage_writer=dist_cp.FileSystemWriter(save_path),
no_dist=True
)
def save_opt_or_scheduler_fsdp(model, opt, save_path, rank, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
torch.save(opt.state_dict(), os.path.join(save_path, f"shard{rank}.pt"))
def load_opt_or_scheduler_fsdp(model, opt, save_path, rank, offload_to_cpu=True):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=offload_to_cpu)):
state_dict = torch.load(os.path.join(save_path, f"shard{rank}.pt"))
opt.load_state_dict(state_dict)
def save_model_opt_scheduler_states_fsdp(model, opt, scheduler, step_count, checkpoint_path, rank, dont_save_opt=False, keep_old_checkpoints=False):
# TODO: remove rank arguments in this function
# for component, name in [[model, "model"], [opt, "opt"], [scheduler, "scheduler"]]:
path = os.path.join(checkpoint_path, str(step_count), "model")
save_state_dict_fsdp(model, path)
if not dont_save_opt:
path = os.path.join(checkpoint_path, str(step_count), "opt")
os.makedirs(path, exist_ok=True)
save_opt_or_scheduler_fsdp(model, opt, path, rank)
path = os.path.join(checkpoint_path, str(step_count), "scheduler")
os.makedirs(path, exist_ok=True)
save_opt_or_scheduler_fsdp(model, scheduler, path, rank)
# remove all other checkpoint path
if rank == 0:
if not keep_old_checkpoints:
remove_all_folders_except_the_latest_step_count(checkpoint_path, step_count)
def load_model_opt_scheduler_states_fsdp(model, opt, scheduler, checkpoint_path):
last_checkpoint_path, start_step_count = get_last_checkpoint_path_and_last_step_count(checkpoint_path)
# for component, name in [[model, "model"], [opt, "opt"], [scheduler, "scheduler"]]:
path = os.path.join(last_checkpoint_path, "model")
load_state_dict_fsdp(model, path)
rank = torch.cuda.current_device()
path = os.path.join(last_checkpoint_path, "opt")
load_opt_or_scheduler_fsdp(model, opt, path, rank)
path = os.path.join(last_checkpoint_path, "scheduler")
load_opt_or_scheduler_fsdp(model, scheduler, path, rank)
return model, opt, scheduler, start_step_count
def remove_all_folders_except_the_latest_step_count(checkpoint_path, cur_step_count):
# get the last checkpoint path
if os.path.exists(checkpoint_path):
for folder in os.listdir(checkpoint_path):
if int(folder) != cur_step_count:
shutil.rmtree(os.path.join(checkpoint_path, folder))
def get_last_checkpoint_path_and_last_step_count(checkpoint_path):
# get the last checkpoint path
if os.path.exists(checkpoint_path):
last_step_count = max(int(x) for x in os.listdir(checkpoint_path))
last_checkpoint_path = os.path.join(checkpoint_path, str(last_step_count))
return last_checkpoint_path, last_step_count+1
return None, None
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def get_all_existing_loggers():
return logging.Logger.manager.loggerDict.values()
def add_padding_token(tokenizer):
print("attempt to add padding token if no padding token exists")
print("Special tokens before adding padding token: ", tokenizer.special_tokens_map)
if not tokenizer.pad_token:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
print("Special tokens after adding padding token: ", tokenizer.special_tokens_map)
return tokenizer
def make_nonpersistent_buffer_persistent(model):
"""FSDP does not appropriately handle non-persistent buffers when the weights are initialized on gpu
    We make these buffers persistent, so that they are written to disk when saving the model
and can be used to override the incorrect non-persistent buffers when loading the model
"""
for name, module in model.named_modules():
if hasattr(module, "_non_persistent_buffers_set") and len(module._non_persistent_buffers_set) > 0:
print(f"moving non-persistent buffers to persistent buffers for module {name}")
module._persistent_buffers_set = module._non_persistent_buffers_set
module._non_persistent_buffers_set = set()
class LogLevelContext:
def __init__(self, level):
self.level = level
self.original_levels = {}
def __enter__(self):
for logger in get_all_existing_loggers():
if isinstance(logger, logging.Logger):
self.original_levels[logger] = logger.level
logger.setLevel(self.level)
def __exit__(self, exc_type, exc_value, traceback):
for logger, original_level in self.original_levels.items():
logger.setLevel(original_level)

Binary file not shown (image, 605 KiB).

71
MaskedThought/README.md Normal file
View File

@ -0,0 +1,71 @@
# Masked Thought
This repository contains the code for our paper:<br>
**Masked Thought: Simply Masking Partial Reasoning Steps Can Improve
Mathematical Reasoning Learning of Language Models**<br>
*Changyu Chen, Xiting Wang, Ting-En Lin, Ang Lv, Yuchuan Wu, Xin Gao, Ji-Rong Wen, Rui Yan, Yongbin Li* <br>
Paper: https://arxiv.org/abs/2403.02178 <br>
### 1. Overview
We propose a simple regularization method, **Masked thought Fine-Tuning (MFT)**, for supervised fine-tuning on mathematical reasoning data.
<div align="center">
<img src="overview.png" width = "550" alt="d" align=center />
</div>
### 2. Main results
<div align="center">
<img src="main_res.png" width = "550" alt="d" align=center />
</div>
### 3. Quick Start
#### 1) Installation
```bash
conda create -n mask python=3.10
conda activate mask
pip install -r requirements.txt
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
```
#### 2) Train
You can start by training on GSM8K with Llama-2-7b using the following command; it takes about one hour on 8 NVIDIA A100 GPUs.
```bash
bash training_scripts/run_llama2_7b_gsm8k_mft.sh
```
#### 3) Evaluation
The evaluation takes about 1 minute using vLLM on a single A100 GPU.
```bash
# exp_name is the experiment identifier in your training script.
bash evaluation/run_gen_math_greedy_vllm_1.sh ${exp_name}
python evaluation/get_gsm8k_res.py --model ${exp_name}
```
For training and evaluation on other datasets and models, you can refer to ```./training_scripts```, ```./MetaMath``` and ```./MAmmoTH```.
You can also refer to the repos of [MetaMath](https://github.com/nlpxucan/WizardLM/tree/main/WizardMath), [MAmmoTH]() and [RFT](https://github.com/OFA-Sys/gsm8k-ScRel/tree/main) for more details about their evaluation and datasets.
### 4. Train with your own code
If you'd like to incorporate MFT into your own code, just add the following code before feeding `input_ids` to the model. `MASK_INDEX` can be the `token_id` of a newly added token `[mask]` or of the `[pad]` token in the original tokenizer, depending on your preference; a sketch of how to obtain it is given after the snippet.
```python
import copy
import random

def mask_target_tokens(input_ids, sources_tokenized, mask_probability, MASK_INDEX):
    masked_input_ids = copy.deepcopy(input_ids)
    for input_id, source_len in zip(masked_input_ids, sources_tokenized["input_ids_lens"]):
        # only mask tokens that belong to the target, i.e. everything after the prompt
        for i in range(source_len, len(input_id)):
            if random.random() < mask_probability:
                input_id[i] = MASK_INDEX
    return masked_input_ids

input_ids = mask_target_tokens(input_ids, sources_tokenized, mask_probability, MASK_INDEX)
```
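As a minimal sketch (assuming a Hugging Face tokenizer; the model identifier below is just a placeholder), `MASK_INDEX` could be obtained like this:
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id

if tokenizer.pad_token_id is not None:
    # Option 1: reuse the existing pad token as the mask token.
    MASK_INDEX = tokenizer.pad_token_id
else:
    # Option 2: register a dedicated [mask] token; remember to call
    # model.resize_token_embeddings(len(tokenizer)) afterwards.
    tokenizer.add_special_tokens({"additional_special_tokens": ["[mask]"]})
    MASK_INDEX = tokenizer.convert_tokens_to_ids("[mask]")
```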
### 5. Citation
```bibtex
@article{chen2024masked,
title={Masked Thought: Simply Masking Partial Reasoning Steps Can Improve Mathematical Reasoning Learning of Language Models},
author={Chen, Changyu and Wang, Xiting and Lin, Ting-En and Lv, Ang and Wu, Yuchuan and Gao, Xin and Wen, Ji-Rong and Yan, Rui and Li, Yongbin},
journal={arXiv preprint arXiv:2403.02178},
year={2024}
}
```

View File

@ -0,0 +1,97 @@
from dataclasses import dataclass,field
from . import register_argumentclass
from transformers import TrainingArguments
@register_argumentclass("base")
@dataclass
class BaseArguments:
scene: str = field(default="realtime_gen",metadata={"help":"Specific Scene"})
@register_argumentclass("train")
@dataclass
class TrainArguments(TrainingArguments):
#preprocess_header: processor:col0:col1:col2:...:keyname
model: str = field(default="local",metadata={"help":"Either local or model identifier from huggingface.co/models"})
task: str = field(default="",metadata={"help":"Model Task"})
train_dir: str = field(default="../Data/train.txt",metadata={"help":"Training data path"})
train_header: str = field(default="",metadata={"help":"Training data header"})
train_model_header: str = field(default="",metadata={"help":"Training preprocess header"})
eval_dir: str = field(default="../Data/eval.txt",metadata={"help":"validation data path"})
eval_header: str = field(default="",metadata={"help":"Eval data header"})
eval_model_header: str = field(default="",metadata={"help":"Eval preprocess header"})
predict_header: str = field(default="",metadata={"help":"Predict data header"})
predict_model_header: str = field(default="",metadata={"help":"Predict preprocess header"})
result_header: str = field(default="",metadata={"help":"Result output header"})
result_file: str = field(default="predict.txt",metadata={"help":"Result file name"})
previous_dir: str = field(default="./cache", metadata={"help": "previous model path"})
output_dir: str = field(default="./model_output", metadata={"help": "Output data path"})
cache_dir: str = field(default="./cache", metadata={"help": "Path to cache auto models"})
logging_dir: str = field(default="./log_output", metadata={"help": "Path to save logging"})
from_tf: bool = field(default=False, metadata={"help": "load model from tf weight"})
datareader_streaming: bool = field(default=False, metadata={"help": "whether to load data in streaming way"})
dataloader_num_workers: int = field(default=3, metadata={"help": "reader workers"})
outputter_num_workers: int = field(default=None,
metadata={"help": "outputter workers, default (None) to dataloader_num_workers"})
extract_chunk: int = field(default=10, metadata={"help": "param for pool.imap, chunks processes for one time"})
shuffle_num_batch: int = field(default=50, metadata={"help": "N batch for shuffle buffer"})
num_train_epochs: int = field(default=1, metadata={"help": "N epochs to train"})
logging_steps: int = field(default=200, metadata={"help": "default step to logging (print, azureml, tensorboard)"})
save_steps: int = field(default=5000, metadata={"help": "default step to save ckpt, should be same as eval_steps"})
save_total_limit: int = field(default=2, metadata={"help": "save total limit"})
eval_steps: int = field(default=2000000, metadata={"help": "evaluation every N steps"})
evaluation_strategy: str = field(default="steps", metadata={"help": "evaluation strategy: no/steps/epoch"})
load_best_model_at_end: bool = field(default=False, metadata={"help": "load best model at the end for save"})
greater_is_better: bool = field(default=True, metadata={"help": "help to judge best model"})
fp16: bool = field(default=False, metadata={"help": "whether to enable mix-precision"})
remove_unused_columns: bool = field(default=False, metadata={
"help": "Remove columns not required by the model when using an nlp.Dataset."})
disable_tqdm: bool = field(default=True, metadata={"help": "Disable tqdm, print log instead"})
# Train Setting
trainer: str = field(default="common", metadata={"help": "user-defined trainer to select"})
eval_metrics: str = field(default="acc", metadata={"help": "Metrics to eval model with delimiter ,"})
evaluate_during_training: bool = field(default=True, metadata={"help": "Whether to enable evaluation during training"})
to_cache: str = field(default="", metadata={"help": "To hotfix fields pop issue in huggingface"})
output_type: str = field(default="score", metadata={"help": "Set outputter type"})
eager_predict: bool = field(default=False, metadata={"help": "Eager prediction for debugging"})
dense_gradients: bool = field(default=False, metadata={"help": "Sync dense gradient for efficiency"})
compress_gradients: bool = field(default=False, metadata={"help": "Sync gradient in fp16"})
label_names: str = field(default=None, metadata={"help": "label names for label ids"})
tok_max_length: int = field(default=1024, metadata={"help": "max length for tokenizer"})
tgt_max_length: int = field(default=512, metadata={"help": "max length for target"})
special_tokens: str = field(default="", metadata={"help": "Set special token for tokenizer, split by ,"})
'''
llm
'''
use_peft: bool = field(default=False)
lora_rank: int = field(default=8)
update_mask_rate: bool = field(default=False)
max_mask_rate: float = field(default=0.4)
mask_rate_warmup_ratio: float = field(default=2/3)
pad_front: bool = field(default=True)
model_name: str = field(default='')
load_in_8bit: bool = field(default=False)
load_in_4bit: bool = field(default=False)
instruct_format: bool = field(default=False)
instruct_type: str = field(default='default')
adapter_path: str = field(default='')
lora_no_emb: bool=field(default=False)
use_vllm: bool=field(default=True)
cut_front: bool=field(default=False)
modules_to_save: bool=field(default=False)
resize_at_begin: bool=field(default=True)
include_tokens_per_second: bool = field(
default=False,
metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
)
ddp_find_unused_parameters: bool=field(default=False)
cut_src: bool=field(default=False)
not_save: bool=field(default=False)
exp_name: str = field(default="exp", metadata={"help": "name for model saving path"})
print_every: int = field(default=1000, metadata={"help": "print log every real steps"})
View File
@ -0,0 +1,56 @@
from dataclasses import dataclass,field
from torch.utils import data
from transformers import (
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
AutoModelForPreTraining,
)
from . import register_argumentclass
from transformers.models.auto import configuration_auto
from transformers.models.auto import tokenization_auto
from transformers.models.bert.tokenization_bert import BertTokenizer
class BaseArguments:
def process(self):
#return {"smile":"^_^"}
return {}
@register_argumentclass("model_seq2seq")
@dataclass
class Seq2SeqArguments(BaseArguments):
#Generation
auto_model = AutoModelForSeq2SeqLM
mask_input: bool = field(default=True)
mask_rate: float = field(default=0)
neftune_alpha: float=field(default=0)
token_dropout: float=field(default=0)
only_drop_target: bool=field(default=False)
neft_all_token: bool=field(default=False)
drop_attention_mask: bool = field(default=False)
token_noise: bool=field(default=False)
token_noise_hard: bool=field(default=False)
token_noise_sample: bool=field(default=False)
token_noise_random: bool=field(default=False)
mixup: bool=field(default=False)
replace: bool=field(default=False)
replace_rate: float=field(default=0)
temperature: float = field(default=1)
not_mask_source: bool=field(default=True)
not_mask_tgt: bool=field(default=False)
View File
@ -0,0 +1,21 @@
import importlib
import os
register_argumentclass_dict = dict()
def register_argumentclass(name):
def register_argumentclass_class(argumentclass_class):
register_argumentclass_dict[name] = argumentclass_class
return argumentclass_class
return register_argumentclass_class
__all__ = []
argumentclass_dir = os.path.dirname(__file__)
for f in os.listdir(argumentclass_dir):
fpath = os.path.join(argumentclass_dir,f)
if not f.startswith('.') and (f.endswith('.py')):
fname = f[:f.find('.py')]
module = importlib.import_module(f'.{fname}','config.ArgumentClass')
for key,cls in register_argumentclass_dict.items():
globals()[key] = cls
View File
@ -0,0 +1,16 @@
--train_header data_src,data_gt,data_kd
--train_model_header llama_s2s_tokenize:data_src:data_gt:input,data_src,data_gt,data_kd
--eval_header data_src,data_gt,data_kd
--eval_model_header llama_s2s_tokenize:data_src:data_gt:input,data_src,data_gt,data_kd
--train_dir static_data/cnndm/cnndm_train.tsv \
--eval_dir static_data/cnndm/cnndm_valid.tsv \
--eval_metrics bleu
--predict_header data_src,data_gt
--predict_model_header llama_s2s_tokenize:data_src:input,data_src,data_gt
--result_header data_src,data_gt
--output_type llama_generation
--task seq2seq
--tok_max_length 1024
--save_steps 10000000
--model decapoda-research/llama-7b-hf
bo
View File
@ -0,0 +1,57 @@
import sys
import inspect
from transformers import logging
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
def replace(target_obj, is_allowed=True):
""" Copy from fastseq: https://github.com/microsoft/fastseq
A decorator to replace the specified obj.
`target_obj` can be a class or a function.
Example:
```python
class A:
def f(self):
print('class A')
@replace(A)
class B:
def f(self):
print('class B')
```
Args:
target_obj (class/func/method): a class, method, or function to be
replaced.
Returns:
A decorator function to replace the input object.
"""
def decorator(new_obj):
if not is_allowed:
return target_obj
# if target_obj in OPTIMIZED_CLASSES:
# logger.warning("{} has been optimized again.".format(target_obj))
setattr(new_obj, '__replaced_class__', target_obj)
# OPTIMIZED_CLASSES[target_obj] = new_obj
for k, v in list(sys.modules.items()):
if (target_obj.__name__ in v.__dict__
and v.__dict__[target_obj.__name__] is target_obj):
delattr(sys.modules[k], target_obj.__name__)
setattr(sys.modules[k], target_obj.__name__, new_obj)
logger.warning("In module {}, {} is replaced by {}".format(
k, target_obj, new_obj))
# replace target_obj if it is used as the base classes.
for key in list(v.__dict__.keys()):
if (inspect.isclass(v.__dict__[key]) and
v.__dict__[key] != new_obj and
target_obj in v.__dict__[key].__bases__):
idx = v.__dict__[key].__bases__.index(target_obj)
bases = list(v.__dict__[key].__bases__)
bases[idx] = new_obj
v.__dict__[key].__bases__ = tuple(bases)
logger.warning(
"In module {}, the base class of {} is replaced by {}"
.format(k, v.__dict__[key], new_obj))
return new_obj
return decorator
View File
@ -0,0 +1,64 @@
import dataclasses
import functools
import os
from transformers import (
HfArgumentParser,
logging,
AutoConfig
)
import config.ArgumentClass as ArgumentClass
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
def identifier(cls):
return cls.__class__.__name__.split(".")[-1][:-6]
def print_arg(args):
for k,v in args.__dict__.items():
logger.info(k + ": " + str(v))
def print_args(argset):
for setname,singlearg in argset.items():
logger.info("<<<<<<<<"+setname+">>>>>>>>")
print_arg(singlearg)
def parse_args():
allargs = {}
folder = os.path.dirname(os.path.abspath(__file__))
#01. Base Configs
parser = HfArgumentParser(ArgumentClass.base)#(getattr(ArgumentClass,'base')))
allargs["base_args"],remain = parser.parse_args_into_dataclasses(return_remaining_strings=True)
#02. Data & Train Configs
parser = HfArgumentParser(ArgumentClass.train)
allargs["train_args"],remain_train = parser.parse_args_into_dataclasses(remain,args_filename= folder + "/SceneConfigs/"+allargs["base_args"].scene,return_remaining_strings=True)
#03. Model Configs
model_identifier = allargs["train_args"].model if not allargs["train_args"].model.startswith("local") else allargs["train_args"].previous_dir
try:
parser = HfArgumentParser((getattr(ArgumentClass,'model_' + allargs["train_args"].task)))
except:
logger.error("Task " + allargs["train_args"].task + " not registered!")
exit(0)
allargs["task_args"],remain_task = parser.parse_args_into_dataclasses(remain,args_filename=folder + "/SceneConfigs/"+allargs["base_args"].scene,return_remaining_strings=True)
task_dict = dataclasses.asdict(allargs["task_args"])
task_dict.update(allargs["task_args"].process())
remain = [k for k in remain_task if k in remain_train and k != "http" and k[:2] == "--"]
if remain:
logger.error("unused command line args:" + str(remain))
exit(1)
allargs["model_args"],remain = AutoConfig.from_pretrained(model_identifier,return_unused_kwargs=True,_name_or_path=model_identifier,**task_dict)
for k in remain:
setattr(allargs["model_args"],k,remain[k])
if remain:
logger.warning("unused args:" + str(remain))
if "DLTS_JOB_ID" in os.environ:
log_dir = os.path.expanduser('~/tensorboard/{}/logs/'.format(os.environ['DLTS_JOB_ID']))
allargs["train_args"].logging_dir = log_dir
logger.info("Replace Tensorboard dir to " + log_dir)
print_args(allargs)
if not os.path.exists(allargs["train_args"].logging_dir) and allargs["train_args"].local_rank <= 0:
os.makedirs(allargs["train_args"].logging_dir, exist_ok=True)
if not os.path.exists(allargs["train_args"].output_dir) and allargs["train_args"].local_rank <= 0:
os.makedirs(allargs["train_args"].output_dir, exist_ok=True)
return allargs["base_args"],allargs["train_args"],allargs["model_args"],allargs["task_args"]
View File
@ -0,0 +1,29 @@
from . import register_processor
from transformers import AutoTokenizer
import numpy
class BaseProcessor:
def __init__(self, cfg, model, **kwargs):
self.out_key = []
self.padding_values = []
tokenizer = kwargs.pop("tokenizer", None)
self.fn = tokenizer if tokenizer else AutoTokenizer.from_pretrained(model, cache_dir=cfg.cache_dir)
def property(self):
return {
"values":dict([key,value] for key,value in zip(self.out_key,self.padding_values)) if self.padding_values else {}
}
@register_processor("basic")
class SelectColumn(BaseProcessor):
def __init__(self,idx,out_name,model=None,cfg = None,task_cfg=None):
self.idx = idx
self.out_key = [out_name]
self.out_name = out_name
self.padding_values = None
def process(self,columns):
try:
return {self.out_name:columns[self.idx]}
except IndexError as e:
print(e)
print('select', columns)
return {self.out_name:""}
View File
@ -0,0 +1,134 @@
from numpy.core.records import array
from . import register_processor
from transformers import AutoTokenizer
import numpy
from .BasicProcessor import BaseProcessor
import random
from transformers import logging
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
from typing import List
import torch
import numpy as np
@register_processor("s2s_tokenize")
class S2STokenize(BaseProcessor):
def __init__(self,idx,out_name,model,cfg=None,task_cfg=None,**kwargs):
super().__init__(cfg, model, **kwargs)
self.idx = idx
self.max_length = getattr(cfg,"tok_max_length",512) if cfg else 512
self.max_target_length = getattr(cfg,"max_length", 128) if cfg else 128
if len(self.idx) == 1:
self.out_key = ["input_ids","attention_mask"]
self.padding_values = [0,0]
else:
self.out_key = ["input_ids","attention_mask","labels"]
self.padding_values = [0,0,-100]
if cfg.local_rank <= 0:
self.fn.save_pretrained(cfg.output_dir)
def process(self,columns):
if len(self.idx) == 1:
try:
res = self.fn(columns[self.idx[0]],return_tensors='np',max_length=self.max_length,truncation=True)
except:
logger.warning("Data Error")
res = self.fn('',return_tensors='np',max_length=self.max_length,truncation=True)
else:
try:
res = self.fn(columns[self.idx[0]],return_tensors='np',max_length=self.max_length,truncation=True)
# with self.fn.as_target_tokenizer():
labels = self.fn(text_target=columns[self.idx[1]],return_tensors='np',max_length=self.max_target_length,truncation=True)
res["labels"] = labels["input_ids"]
except:
res = self.fn('',return_tensors='np',max_length=self.max_length,truncation=True)
# with self.fn.as_target_tokenizer():
labels = self.fn(text_target='nonenone',return_tensors='np',max_length=self.max_target_length,truncation=True)
res["labels"] = labels["input_ids"]
logger.warning("Data Error" + str(columns))
return dict([(k,numpy.squeeze(v,0)) for k,v in res.items()])
@register_processor("llama_s2s_tokenize")
class llamaS2STokenize(BaseProcessor):
def __init__(self, idx, out_name, model, cfg=None, task_cfg=None, **kwargs):
super().__init__(cfg, model, **kwargs)
self.idx = idx
self.max_length = getattr(cfg, "tok_max_length", 512) if cfg else 512
self.max_target_length = getattr(cfg, "max_length", 128) if cfg else 128
self.out_key = ["input_ids", "attention_mask", "labels", "label_mask"]
self.padding_values = [0, 0, -100, 0]
if cfg.local_rank <= 0:
self.fn.save_pretrained(cfg.output_dir)
def process(self, columns):
if len(self.idx) == 1:
try:
input_text = columns[self.idx[0]]
# input_text = input_text.split('<#SEP#>')
input_text = " ".join(input_text)
input_text = input_text.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>",
"\n\nAssistant: ")
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
input_text = PROMPT_DICT['prompt_input'].format(instruction='Give a response as the assistant with the input conversation history', input=input_text)
res = self.fn(input_text, return_tensors='np',max_length=self.max_length-128,truncation=True)
except:
logger.warning("Data Error")
res = self.fn('',return_tensors='np',max_length=self.max_length-128,truncation=True)
else:
try:
input_text = columns[self.idx[0]]
label_text = columns[self.idx[1]]
except IndexError as e:
print(e)
print('tokenizeP', columns)
input_text = "none none"
label_text = "none none"
# input_text = input_text.split('<#SEP#>')
input_text = "".join(input_text)
input_text = input_text.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>", "\n\nAssistant: ").rstrip()
p = input_text + " " + label_text
# print(p)
# assert 1==0
# p = p.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>", "\n\nAssistant: ").rstrip()
res = self.fn(p, return_tensors='np', max_length=self.max_length-128, truncation=True)
labels = self.fn(text_target=label_text, return_tensors='np', max_length=self.max_target_length,
truncation=True)
res["labels"] = labels["input_ids"]
labels_len = res["labels"].shape[1] # 1 (bos) + len_2
seq_len = res["attention_mask"].shape[1] # 1 (bos) + len_1 + len_2
# 1 + len_1
label_mask = [[0 if i < 1 + seq_len-labels_len else 1 for i in range(seq_len)]]
res["label_mask"] = numpy.array(label_mask)
return dict([(k, numpy.squeeze(v, 0)) for k, v in res.items()])
View File
@ -0,0 +1,22 @@
import importlib
import os
registered_processor = dict()
def register_processor(name):
def wrapper(processor_cls):
registered_processor[name] = processor_cls
return processor_cls
return wrapper
processor_dir = os.path.dirname(__file__)
for f in os.listdir(processor_dir):
fpath = os.path.join(processor_dir,f)
if not f.startswith('.') and (f.endswith('.py')):
fname = f[:f.find('.py')]
module = importlib.import_module(f'.{fname}','data.Processor')
for key,cls in registered_processor.items():
globals()[key] = cls
View File
@ -0,0 +1,252 @@
import os
import sys
import torch
import datasets
from datasets.packaged_modules.text.text import Text
from config.decorator import replace
#from datasets import load_dataset
datasets.logging.set_verbosity(datasets.logging.ERROR)
from dataclasses import dataclass
import data.Processor as Processor
from torch.nn.utils.rnn import pad_sequence
import pyarrow as pa
from datasets import Features, Sequence, Value
from transformers.trainer_pt_utils import IterableDatasetShard
sys.path.append(".")
import json
class IterDataSet(torch.utils.data.IterableDataset):
def __init__(self,files,mode):
super().__init__()
self.files = files
self.mode = mode
self.rowcnt = self.estimate_samples(self.files)
def __len__(self):
return self.rowcnt
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
num_worker = 1
worker_id = 0
else:
num_worker = worker_info.num_workers
worker_id = worker_info.id
cnt = 0
for f in self.files:
for line in open(f,encoding='utf-8'):
if cnt % num_worker == worker_id:
yield {self.mode: line.rstrip("\n")}
def estimate_samples(self,filelist):
cnt = 0
for file in filelist:
cnt += self.row_count(file)
return cnt
def row_count(self,filename):
f = open(filename, 'rb')
lines = 0
buf_size = 1024 * 1024
read_f = f.raw.read
buf = read_f(buf_size)
while buf:
lines += buf.count(b'\n')
buf = read_f(buf_size)
return lines
class InputPipe:
"""
base class to store and process dataset
"""
def __init__(self, model, train_args, task_args, mode, auto_tokenizer):
self.input_header = getattr(train_args, mode + "_header")
self.model_header = getattr(train_args, mode + "_model_header")
self.train_args = train_args
self.task_args = task_args
self.mode = mode
self.model = model
self.initialized = True
self.workers = self.train_args.dataloader_num_workers if self.train_args.dataloader_num_workers else 1
mode_identifier = "train" if mode == "train" else "eval"
self.input_files = self.get_filenames(getattr(train_args,mode_identifier + "_dir"))
print(self.input_files)
if train_args.datareader_streaming:
self.ds = IterDataSet(self.input_files,mode_identifier)
else: #Default
# import `DataBuilder` (with `BuilderConfig`) from 'text' module to load dataset
self.ds = datasets.load_dataset('text',data_files=self.input_files,encoding='utf-8',features=Features({mode_identifier:Value("string")}))["train"]
self.feature_extractor = FeatureExtractor(self.input_header, self.model_header, self.model, self.train_args,
self.task_args, auto_tokenizer)
def get_dataset(self):
return self.ds
def get_filenames(self,name):
if os.path.isdir(name):
return [name + "/" + i for i in os.listdir(name)]
else:
return [name]
class FeatureExtractor():
"""
Feature extractor to process features from dataloader in train/eval mode
Args :
input_header (:obj:`str`):
of form ``train/eval_header``
model_header (:obj:`str`):
of form ``tokenizer:header1:header2:data_type``
model (:obj:`PretrainedModel`)
loaded model with model config
train_args (:class:`~transformers.TrainingArguments`, `optional`):
general train arguments
task_args (:class:`~transformers.TrainingArguments`, `optional`):
specific task arguments
"""
def __init__(self, input_header, model_header, model, train_args, task_args, auto_tokenizer):
self.input_header = dict((h,i) for i,h in enumerate(input_header.split(",")))
self.model_header = model_header.split(",")
self.model = model
self.train_args = train_args
self.task_args = task_args
self.out_types,self.out_shapes,self.out_paddings,self.padding_values= {},{},{},{}
self.processors = dict()
self.tokenizer = auto_tokenizer
self.init_proc()
self.set_property()
def __call__(self,x):
return self.preprocess(x)
def init_proc(self):
""" initialize line processor, e.g, tokenizer processor """
for mh in self.model_header:
cols = mh.split(":")
out_name = cols[-1]
if out_name in self.processors:
raise ValueError(f"processor for out_name={out_name} is already registered with Processor {self.processors[out_name]}")
if len(cols) == 1:
cls = getattr(Processor,'basic')
self.processors[out_name] = cls(self.input_header[out_name],out_name)
elif len(cols) == 2:
cls = getattr(Processor,cols[0])
self.processors[out_name] = cls(self.input_header[out_name],out_name)
else:
# out_name : input / doc
cls = getattr(Processor, cols[0])
feat_idx = [self.input_header[x] for x in cols[1:-1]]
self.processors[out_name] = cls(idx=feat_idx, out_name=out_name, model=self.model, cfg=self.train_args,
task_cfg=self.task_args, tokenizer=self.tokenizer)
def preprocess(self,line):
""" Process input columns in batch form using e.g., tokenizer """
if isinstance(line, dict):
line = line['text']
elif not isinstance(line, str):
line = line.decode('utf-8')
line = json.loads(line)
column = [line['source'][0], line['target'][0], line['target'][0]]
if len(column) != 3:
print(line)
res = dict()
for n in self.processors:
tmp = self.processors[n].process(column)
if tmp == None:
return None
res.update(tmp)
return res
def set_property(self):
"""
set padding values as map(`dict`) for e.g, "input_ids","attention_mask","labels"
"""
for n in self.processors:
p = self.processors[n].property()
self.padding_values.update(p["values"])
@dataclass
class CustomizeCollator:
"""
Collator of train/eval dataloader, process and tensorize features if needed
"""
train_dataset:InputPipe
eval_dataset:InputPipe
@property
def feature_extractor(self):
fn = {}
fn["train"] = self.train_dataset.feature_extractor if self.train_dataset else None
fn["eval"] = self.eval_dataset.feature_extractor if self.eval_dataset else None
return fn
def __call__(self, features):
raw_features = []
res = {}
mode = None
# get processed lines using tokenizer (feature_extractor)
# features struct : {'train/eval': raw line}
for sample in features:
for key,line in sample.items():
if mode == None:
mode = key
raw_features.append(self.feature_extractor[key](line))
# convert batch data (run in training) with padding data
for fn in raw_features[0]:
# todo
if fn in self.feature_extractor[mode].padding_values:
if (isinstance(raw_features[0][fn], list) and isinstance(raw_features[0][fn][0], list)) or (hasattr(raw_features[0][fn], "shape") and len(raw_features[0][fn].shape)>=2):
fv = [torch.tensor(f[fn][i]) for f in raw_features for i in range(len(f[fn]))]
else:
fv = [torch.tensor(f[fn]) for f in raw_features]
if self.feature_extractor[mode].padding_values[fn] == None:
res[fn] = torch.stack(fv)
else:
res[fn] = pad_sequence(fv,batch_first=True,padding_value=self.feature_extractor[mode].padding_values[fn])
else:
res[fn] = [f[fn] for f in raw_features]
return res
def input_builder(model, train_args, task_args=None, tokenizer=None):
train_pipe = InputPipe(model, train_args, task_args, 'train', tokenizer) if train_args.do_train else None
eval_pipe = InputPipe(model, train_args, task_args, 'eval', tokenizer) if train_args.do_eval else None
predict_pipe = InputPipe(model, train_args, task_args, 'predict', tokenizer) if train_args.do_predict else None
return train_pipe, eval_pipe, predict_pipe
@replace(Text)
class Text(Text):
def _generate_tables(self, files):
schema = pa.schema(self.config.features.type if self.config.features is not None else {"text": pa.string()})
for file_idx, file in enumerate(files):
batch_idx = 0
with open(file, "r", encoding=self.config.encoding) as f:
while True:
batch = f.read(self.config.chunksize)
if not batch:
break
batch += f.readline() # finish current line
batch = batch.strip("\n").split("\n")
pa_table = pa.Table.from_arrays([pa.array(batch)], schema=schema)
# Uncomment for debugging (will print the Arrow table size and elements)
yield (file_idx, batch_idx), pa_table
batch_idx += 1
@replace(IterableDatasetShard)
class IterableDatasetShard(IterableDatasetShard):
def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)
self.num_examples = len(self.dataset)
def __len__(self):
return (len(self.dataset) - 1 - self.process_index) // self.num_processes + 1
if __name__ == "__main__":
from config.parse_args import *
base_args,train_args,model_args,task_args = parse_args()
test = InputPipe("bert-base-uncased",train_args,'eval')
ds = test.get_dataset()
File diff suppressed because it is too large

File diff suppressed because it is too large

View File
@ -0,0 +1,13 @@
from transformers import AutoTokenizer
import torch
def prepare_tokenizer(model, cache_dir, **kwargs):
special_tokens = kwargs.pop("special_tokens", None)
if special_tokens:
special_tokens = special_tokens.split(",")
if model == 'Yale-LILY/brio-xsum-cased':
model = 'google/pegasus-xsum'
if model == 'Yale-LILY/brio-cnndm-uncased':
model = 'facebook/bart-large-cnn'
auto_tokenizer = AutoTokenizer.from_pretrained(model, cache_dir=cache_dir)
auto_tokenizer.add_tokens(special_tokens)
return auto_tokenizer
View File
@ -0,0 +1,133 @@
# import models
import torch
import transformers
from transformers import logging
from config.parse_args import *
from data.data_reader import *
from transformers import Trainer
logging.set_verbosity_info()
logging.enable_explicit_format()
import logging as local_logging
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
local_logging.basicConfig(format="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s",level=logging.INFO)
from peft import get_peft_config, get_peft_model, PeftModel, LoraConfig, TaskType, prepare_model_for_int8_training, prepare_model_for_kbit_training
from data.tokenizer_utils import prepare_tokenizer
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED, as_completed, ALL_COMPLETED
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
from vllm import LLM, SamplingParams
from tqdm import tqdm
import torch.distributed as dist
def ddp_setup(rank: int, world_size: int):
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
rank = os.environ["LOCAL_RANK"]
rank = int(rank)
print('rank', rank)
num_gpus = torch.cuda.device_count()
ddp_setup(rank, num_gpus)
n = 1
sampling_params = SamplingParams(temperature=0, max_tokens=512, n=n)
# import time
base_args,train_args,model_args,task_args = parse_args()
model_name = train_args.model_name
auto_tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
print(auto_tokenizer.pad_token)
print(auto_tokenizer.bos_token)
print(auto_tokenizer.eos_token)
print(auto_tokenizer.unk_token)
print(auto_tokenizer.pad_token_id)
print(auto_tokenizer.eos_token_id)
special_tokens_dict = dict()
if auto_tokenizer.pad_token is None:
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
if auto_tokenizer.eos_token is None:
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
if auto_tokenizer.bos_token is None:
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
if auto_tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
auto_tokenizer.add_special_tokens(special_tokens_dict)
auto_tokenizer.add_tokens(["<mask>"])
train_input, eval_input, predict_input = input_builder(model_args._name_or_path, train_args, task_args,
auto_tokenizer)
kwargs = {}
model = LLM(model_name, tokenizer_mode='slow')
print('finish build LLM')
json_data = open('data/gsm8k_test.json', 'r').readlines()
json_data = [json.loads(e.strip()) for e in json_data]
all_res = []
steps = 0
all_input = []
all_ori_input = []
for i, k in enumerate(json_data):
input_text = k['source'][0]
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
input_text = PROMPT_DICT['prompt_no_input'].format(instruction=input_text)
all_input.append(input_text)
all_ori_input.append(k)
if train_args.use_vllm:
gen_res = model.generate(all_input, sampling_params)
assert len(gen_res) == len(all_ori_input)
for i in range(len(gen_res)):
gen_res_list = [gen_res[i].outputs[j].text for j in range(n)]
all_ori_input[i]['kd_data'] = gen_res_list
all_res.append(json.dumps(all_ori_input[i]))
else:
gen_res = []
batch_size = 25
batches = [all_input[i:i + batch_size] for i in range(0, len(all_input), batch_size)]
model.eval()
for i in tqdm(range(len(batches))):
batch = batches[i]
auto_tokenizer.padding_side = "left"
input_ids = auto_tokenizer(batch, return_tensors="pt", padding=True, truncation=True).input_ids
print(input_ids)
with torch.no_grad():
output_ids = model.generate(input_ids, max_new_tokens=256, num_return_sequences=1, num_beams=1, bos_token_id=auto_tokenizer.bos_token_id, eos_token_id=auto_tokenizer.eos_token_id, pad_token_id=auto_tokenizer.eos_token_id)
for j, output_id in enumerate(output_ids):
gen_res.append(auto_tokenizer.decode(output_id, skip_special_tokens=True))
print(gen_res[-1])
assert len(gen_res) == len(all_ori_input)
for i in range(len(gen_res)):
gen_res_list = [gen_res[i]]
all_ori_input[i]['kd_data'] = gen_res_list
all_res.append(json.dumps(all_ori_input[i]))
train_args.model_name = train_args.model_name.replace('/','')
output_file = open(train_args.output_dir + '/' + 'predict_greedy_{}.json'.format(train_args.model_name), 'w')
output_file.write('\n'.join(all_res))
View File
@ -0,0 +1,74 @@
import nltk
import argparse
parser = argparse.ArgumentParser(description="")
parser.add_argument('--type', type=str, default='greedy')
parser.add_argument('--model', type=str)
args = parser.parse_args()
args.model = args.model.replace('/','')
x = []
tmp = open('eval_output/llama/math/predict_{}_{}.json'.format(args.type, args.model)).readlines()
x += tmp
import json
import re
sc_correct = 0
correct = 0
tot = 0
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"
def get_bleu(hyp, ref):
hyp = hyp.strip()
ref = ref.strip()
return nltk.translate.bleu_score.sentence_bleu([ref], hyp)
def extract_answer(completion):
if completion.find('\u0000') >= 0:
completion = completion[0:completion.find('\u0000')]
match = ANS_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
try:
float(match_str)
except BaseException:
return INVALID_ANS
return match_str
else:
return INVALID_ANS
invalid = 0
res = []
all_length = []
all_bleus = []
length_accuracy_data = []
for t in x:
json_data = json.loads(t)
suffix = json_data['suffix'][0]
gold = extract_answer(suffix)
flag = False
for pred in json_data['kd_data']:
ans = extract_answer(pred)
length_of_statement = len(suffix.split())
if ans == INVALID_ANS:
invalid += 1
json_data['judge'] = 'invalid'
elif ans != INVALID_ANS and abs(float(ans) - float(gold)) < 1e-4:
correct += 1
json_data['judge'] = 'true'
else:
json_data['judge'] = 'false'
all_length.append(length_of_statement)
length_accuracy_data.append((length_of_statement, json_data['judge']))
all_bleus.append(get_bleu(ans, gold))
res.append(json.dumps(json_data))
out = open('eval_output/llama/math/predict_{}_{}.json'.format(args.type, args.model), 'w')
out.write('\n'.join(res))
print(args.model)
print("invalid", invalid/len(x))
print('acc', correct/len(x))
print('bleu', sum(all_bleus)/len(x))
print('length', sum(all_length)/len(x))
View File
@ -0,0 +1,21 @@
#!/bin/bash
set -e
output_dir=eval_output/llama/math
mkdir -p ${output_dir}
checkpoint=$1
CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 1 --nproc_per_node=1 --master_port 61234 gen_math_greedy.py \
--scene llama_generation \
--train_dir data/gsm8k_train.json \
--eval_dir data/gsm8k_train.json \
--previous_dir . \
--output_dir ${output_dir}/ \
--do_predict --do_eval \
--per_device_eval_batch_size 1 \
--num_return_sequences 1 \
--report_to none \
--max_length 768 \
--tok_max_length 256 \
--model ${checkpoint} \
--model_name ${checkpoint} \
--use_vllm True
228
MaskedThought/main.py Normal file
View File
@ -0,0 +1,228 @@
from models import modeling_llama, mask_policy_utils
import torch
import transformers
from transformers import logging
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from config.parse_args import *
from data.data_reader import *
import torch.distributed as dist
import trainer.Trainer as Trainer
logging.set_verbosity_info()
import torch.nn as nn
import numpy as np
import random
import math
logging.enable_explicit_format()
import logging as local_logging
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
local_logging.basicConfig(format="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s",level=logging.INFO)
from transformers import (
BitsAndBytesConfig,
)
try:
from peft import get_peft_config, get_peft_model, PeftModel, LoraConfig, TaskType, prepare_model_for_int8_training, prepare_model_for_kbit_training
except:
pass
from data.tokenizer_utils import prepare_tokenizer
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED, as_completed, ALL_COMPLETED
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# from vllm import LLM
# import time
import os
local_rank = int(os.environ['LOCAL_RANK'])
def setup_seed(seed=42):
torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark=False
torch.backends.cudnn.deterministic=True
os.environ['PYTHONHASHSEED'] = str(seed)
# set_seed(args.seed)
def main():
base_args,train_args,model_args,task_args = parse_args()
setup_seed(train_args.seed)
model_name = train_args.model_name
auto_tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
print(auto_tokenizer.unk_token)
print(auto_tokenizer.unk_token_id)
# < unk >
# 0
print(auto_tokenizer.eos_token)
print(auto_tokenizer.eos_token_id)
# < / s >
# 2
print(auto_tokenizer.bos_token)
print(auto_tokenizer.bos_token_id)
# < s >
# 1
print(auto_tokenizer.pad_token)
print(auto_tokenizer.pad_token_id)
# [PAD]
# 32000
special_tokens_dict = dict()
if auto_tokenizer.pad_token is None:
special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
if auto_tokenizer.eos_token is None:
special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
if auto_tokenizer.bos_token is None:
special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
if auto_tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
auto_tokenizer.unk_token = "<unk>"
auto_tokenizer.bos_token = "<s>"
auto_tokenizer.eos_token = "</s>"
auto_tokenizer.add_special_tokens(special_tokens_dict)
auto_tokenizer.add_tokens(["<mask>"])
def resize(model):
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
old_num_tokens = input_embeddings.shape[0]
num_new_tokens = len(auto_tokenizer) - old_num_tokens
print('num_new_tokens', num_new_tokens)
if num_new_tokens != 0:
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
model.resize_token_embeddings(len(auto_tokenizer))
model.get_input_embeddings().weight.data[-num_new_tokens:] = input_embeddings_avg
model.get_output_embeddings().weight.data[-num_new_tokens:] = output_embeddings_avg
train_input, eval_input, predict_input = input_builder(model_args._name_or_path, train_args, task_args,
auto_tokenizer)
auto_model = task_args.auto_model if hasattr(task_args,'auto_model') else getattr(models,train_args.task)[identifier(model_args)]
kwargs = {}
if hasattr(auto_model,'set_cfg'):
kwargs["customize_cfg"] = task_args
kwargs["train_cfg"] = train_args
if train_args.load_in_8bit:
compute_dtype = (torch.float16 if train_args.fp16 else (torch.bfloat16 if train_args.bf16 else torch.float32))
max_memory = {i: '30000MB' for i in range(torch.cuda.device_count())}
model = AutoModelForCausalLM.from_pretrained(
model_name,
config=model_args,
load_in_8bit=True,
max_memory=max_memory,
quantization_config=BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
# bnb_4bit_use_double_quant=args.double_quant,
# bnb_4bit_quant_type=args.quant_type,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
),
torch_dtype=(torch.float32 if train_args.fp16 else (torch.bfloat16 if train_args.bf16 else torch.float32)),
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
elif train_args.load_in_4bit:
compute_dtype = (torch.float16 if train_args.fp16 else (torch.bfloat16 if train_args.bf16 else torch.float32))
max_memory = {i: '30000MB' for i in range(torch.cuda.device_count())}
model = AutoModelForCausalLM.from_pretrained(
model_name,
config=model_args,
load_in_4bit=True,
max_memory=max_memory,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
# bnb_4bit_use_double_quant=args.double_quant,
# bnb_4bit_quant_type=args.quant_type,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
),
torch_dtype=(torch.float32 if train_args.fp16 else (torch.bfloat16 if train_args.bf16 else torch.float32)),
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_name, config=model_args, low_cpu_mem_usage=False)
if train_args.resize_at_begin:
resize(model)
def find_all_linear_names( model):
print(model.config)
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
print('lora_module_names', lora_module_names)
return list(lora_module_names)
if train_args.use_peft:
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj', 'lm_head', 'embed_tokens']
if train_args.lora_no_emb:
target_modules.remove('lm_head')
target_modules.remove('embed_tokens')
modules_to_save = []
if train_args.modules_to_save:
modules_to_save = ['lm_head', 'embed_tokens']
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False,
target_modules=target_modules,
modules_to_save=modules_to_save,
r=train_args.lora_rank, lora_alpha=32, lora_dropout=0.05, bias="none",
# modules_to_save=['lm_head']
)
if train_args.adapter_path != "":
model = PeftModel.from_pretrained(model, train_args.adapter_path, is_trainable=True)
else:
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
trainer = getattr(Trainer, train_args.trainer)(
model=model,
args = train_args,
model_args = model_args,
train_dataset = train_input,
eval_dataset = eval_input if not train_args.do_predict else predict_input,
task_args = task_args,
auto_tokenizer=auto_tokenizer,
)
if train_args.do_train:
trainer.train()
if train_args.do_predict:
trainer.predict()
if __name__ == "__main__":
logger.info("checking GPU")
if not torch.cuda.is_available():
logger.warning("torch.cuda.is_available() Fail")
else:
logger.info("torch.cuda.is_available() Succeed")
main()
BIN
MaskedThought/main_res.png Normal file
Binary file not shown. Size: 270 KiB

View File
@ -0,0 +1,24 @@
import collections
import os
import importlib
registered_model = collections.defaultdict(collections.defaultdict)
def register_model(task,name):
def wrapper(cls):
registered_model[task][name] = cls
return cls
return wrapper
model_dir = os.path.dirname(__file__)
for f in os.listdir(model_dir):
fpath = os.path.join(model_dir,f)
if not f.startswith('.') and (f.endswith('.py')):
fname = f[:f.find('.py')]
try:
module = importlib.import_module(f'.{fname}','models')
except:
pass
for key,value in registered_model.items():
globals()[key] = value
View File
@ -0,0 +1,273 @@
import random
import numpy as np
from numpy import random as np_rand
import torch
import torch.nn as nn
from transformers.modeling_outputs import (
Seq2SeqLMOutput,
CausalLMOutputWithPast
)
class MaskPolicy():
def __init__(self, config, bpe_prefix):
self.config = config
self.bpe_prefix = bpe_prefix
def get_gpt_masked_input(self, labels, target_mask, tokenizer, mask_for='none', types=None):
self.mask_id = tokenizer.encode('<mask>', add_special_tokens=False)[0]
mask_this = True
def pr(ids):
print(tokenizer.convert_ids_to_tokens(ids[0]))
bs = labels.shape[0]
mask_labels = labels.detach().clone() # bs, seq
mask_labels_np = mask_labels.cpu().numpy()
tmp_tks = self.get_tks(mask_labels_np, tokenizer)
non_zero_labels = ~(
labels.data.eq(tokenizer.pad_token_id)) # 0 pad 2 eos -100 pad
# the count of non-special token
non_zero_sum_tensor = non_zero_labels.sum(-1) # bs
non_zero_sum = non_zero_sum_tensor.detach().cpu().numpy().tolist()
masked_pos_shift = torch.zeros_like(mask_labels) # bs, seq
masked_pos_non_shift = torch.zeros_like(mask_labels) # bs, seq
labels_numpy = mask_labels_np
labels_numpy = labels_numpy.tolist()
if target_mask is not None:
target_mask_np = target_mask.cpu().numpy()
if types is not None:
assert bs == len(types)
for i in range(bs):
cand_pos = []
all_pos = []
k = 0
mask_rate = self.config.mask_rate
while k < len(target_mask_np[i]):
if self.config.not_mask_tgt:
if int(target_mask_np[i][k]) == 1:
k += 1
continue
if self.config.not_mask_source:
if int(target_mask_np[i][k]) == 0:
k += 1
continue
if mask_rate == 1:
cand_pos.append(k)
k += 1
continue
all_pos.append(k)
cand_pos.append(k)
k += 1
sample_num = 0
for _ in range(len(cand_pos)):
if random.random() < mask_rate:
sample_num += 1
if mask_rate == 1:
sample_pos = cand_pos
else:
sample_pos = np_rand.choice(a=np.array(cand_pos), size=sample_num, replace=False).tolist()
sample_pos = sorted(sample_pos)
non_sample_pos = set(cand_pos) - set(sample_pos)
non_sample_pos = sorted(list(non_sample_pos))
for idx, j in enumerate(sample_pos):
if self.config.mask_input and mask_this:
if self.config.replace:
r = random.random()
if r < self.config.replace_rate:
tk_id = random.randint(0, int(tokenizer.vocab_size)-1)
mask_labels[i][j] = tk_id
mask_labels[i][j] = mask_labels[i][j]
else:
mask_labels[i][j] = self.mask_id
else:
mask_labels[i][j] = self.mask_id
masked_pos_shift[i][idx] = j + 1
masked_pos_non_shift[i][idx] = j
return mask_labels, masked_pos_shift, masked_pos_non_shift
def get_tks(self, y, tokenizer):
pre_output_ids = y.tolist()
output_ids = []
for i in range(len(pre_output_ids)):
output_id = []
for j in range(0, len(pre_output_ids[i])):
if pre_output_ids[i][j] == -100:
break
output_id.append(pre_output_ids[i][j])
output_ids.append(output_id)
tks = [
tokenizer.convert_ids_to_tokens(output_ids[i]) for i in
range(len(output_ids))]
return tks
def construct_gpt_return(self, lm_logits, labels, bs, masked_pos_shift, masked_pos_non_shift, outputs,
input_ids, mask_labels, tokenizer, ce=False, return_dict=True, loss=None):
device = lm_logits.device
if masked_pos_non_shift is not None:
probs = torch.softmax(lm_logits, dim=-1)
log_probs = torch.log(probs)
log_probs_all = log_probs.detach().clone()
seq_len = labels.shape[1]
mp = torch.zeros((bs, seq_len + 2)).to(device).scatter_(1, masked_pos_shift.to(device),
torch.ones((bs, seq_len + 1)).to(device))
mp = mp[:, 2:]
mp_long = mp.long()
ori_mp = mp.clone()
pads = torch.ones_like(labels, dtype=torch.long).to(device) * tokenizer.pad_token_id
other_2_pads_labels = labels * ~(labels.data.eq(-100) | labels.data.eq(tokenizer.eos_token_id)) + pads * (
labels.data.eq(-100) | labels.data.eq(tokenizer.eos_token_id)) # -100 -> 0#
_, max_ids = torch.max(log_probs, dim=-1)
y_b = other_2_pads_labels * (1 - mp_long) + max_ids * mp_long
_, s2, s3 = probs.shape
gt_prob = torch.gather(probs.reshape(-1, s3), dim=1, index=other_2_pads_labels.reshape(-1, 1))
gt_prob = gt_prob.reshape(-1, s2)
gt_log_probs = torch.log(gt_prob)
masked_ids = max_ids
if not return_dict:
raise NotImplementedError("not return_dict not implemented yet.")
output = CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
if labels is None:
return output
else:
return output, y_b, None, max_ids, masked_ids, input_ids, labels, log_probs, log_probs_all, \
mask_labels.to(input_ids.device), masked_pos_shift.to(input_ids.device), \
masked_pos_non_shift.to(input_ids.device), gt_log_probs
def split_gpt_return(self, lm_logits, input_ids, labels,masked_pos_shift,
masked_pos_non_shift, bs, return_dict, outputs, mask_labels, tokenizer, loss):
res= self.construct_gpt_return(lm_logits=lm_logits, labels=labels, bs=bs,
masked_pos_shift=masked_pos_shift, masked_pos_non_shift=masked_pos_non_shift, return_dict=return_dict,
outputs=outputs, input_ids=input_ids,
mask_labels=mask_labels, tokenizer=tokenizer, loss=loss)
return res
def format_and_padding(self, raw_src_txt, raw_tgt_txt, device, max_len=1024, pad_front=True,
format=False, tgt_max_len=256, cut_front=False, instruct_type='default', cut_src=False, tokenizer=None):
raw_src_txt = ["".join(input_text) for input_text in raw_src_txt]
def process_format(input_text, instruct_type='default'):
if not format:
return input_text
if instruct_type == 'metamath':
def get_input(query):
if query.find('\n') == -1:
return ''
return '\n'.join(query.split('\n')[1:])
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
example = {'instruction': input_text.split('\n')[0], 'input': get_input(input_text)}
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
input = prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(
example)
return input
if instruct_type == 'wizardlm':
problem_prompt = (
"{instruction}\n\n### Response:"
)
input = problem_prompt.format(instruction=input_text)
return input
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
input_text = PROMPT_DICT['prompt_no_input'].format(
instruction=input_text)
return input_text
raw_src_txt = [process_format(e, instruct_type=instruct_type) for e in raw_src_txt]
if cut_src:
ori = tokenizer.truncation_side
tokenizer.truncation_side = "left"
model_inputs = tokenizer(
raw_src_txt,
max_length=max_len - tgt_max_len,
truncation=True,
)
tokenizer.truncation_side = ori
raw_src_txt = tokenizer.batch_decode(model_inputs['input_ids'], skip_special_tokens=True)
raw_tgt_txt = [input_text for input_text in raw_tgt_txt]
input_txt = [txt + ' ' + raw_tgt_txt[i] for i, txt in enumerate(raw_src_txt)]
input_txt_ids = [tokenizer.encode(txt, truncation=False) for txt in input_txt]
label_txt_ids = [tokenizer.encode(txt, truncation=False) for txt in raw_tgt_txt]
input_txt_ids = [ids + [tokenizer.eos_token_id] for ids in input_txt_ids]
label_txt_ids = [ids + [tokenizer.eos_token_id] for ids in label_txt_ids]
label_txt_ids = [label_txt_id[1:] for label_txt_id in label_txt_ids]
label_masks = []
attention_masks = []
for i, input_txt_id in enumerate(input_txt_ids):
seq_len = len(input_txt_id)
label_txt_id = label_txt_ids[i]
label_len = len(label_txt_id)
label_mask = [0 if i < seq_len - label_len else 1 for i in range(seq_len)]
if np.sum(label_mask) == 0:
print(input_txt_id)
print(label_txt_id)
print(seq_len, label_len)
assert 1==0
label_masks.append(label_mask)
attention_mask = [1 for i in range(seq_len)]
attention_masks.append(attention_mask)
def add_padding(txts_ids, pad_id, max_len):
padding_txts_ids = []
batch_max_seq_len = max([len(txt) for txt in txts_ids])
batch_max_seq_len = min(batch_max_seq_len, max_len)
for txt_ids in txts_ids:
if cut_front:
txt_ids = txt_ids[-batch_max_seq_len:]
else:
txt_ids = txt_ids[:batch_max_seq_len]
if pad_front:
padding_txts_ids.append(
[pad_id] * (batch_max_seq_len - len(txt_ids)) + txt_ids)
else:
padding_txts_ids.append(
txt_ids + [pad_id] * (batch_max_seq_len - len(txt_ids)) )
return padding_txts_ids
padding_txts_ids = add_padding(input_txt_ids, pad_id=tokenizer.pad_token_id, max_len=max_len)
padding_label_txt_ids = add_padding(label_txt_ids, pad_id=-100, max_len=max_len)
padding_attention_mask = add_padding(attention_masks, pad_id=0, max_len=max_len)
padding_label_mask = add_padding(label_masks, pad_id=0, max_len=max_len)
return torch.tensor(padding_txts_ids, dtype=torch.long).to(device),\
torch.tensor(padding_label_txt_ids, dtype=torch.long).to(device), \
torch.tensor(padding_attention_mask, dtype=torch.long).to(device), \
torch.tensor(padding_label_mask, dtype=torch.long).to(device),
View File
@ -0,0 +1,389 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import transformers
from transformers.models.llama.modeling_llama import *
from config.decorator import replace
# from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
# from ...modeling_utils import PreTrainedModel
# from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
# from .configuration_llama import LlamaConfig
from models.mask_policy_utils import MaskPolicy
# logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
import time
def timer(func):
def wrapper(*args, **kwargs):
start = time.time()
res = func(*args, **kwargs)
print('Elapsed time: approximately {:.2f}s'.format(time.time() - start))
return res
return wrapper
@replace(LlamaForCausalLM)
class LlamaForCausalLM(LlamaForCausalLM):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.config = config
try:
self.mask_input = config.mask_input
self.mask_rate = config.mask_rate
except:
self.mask_input = False
self.mask_rate = 0
self.vocab_size = config.vocab_size
self.mask_policy = MaskPolicy( config=config, bpe_prefix='')
def sharing_encoder(self, flag):
self.model.is_sharing_encoder = flag
def forward (self,**kwargs):
if not self.training and "past_key_values" not in kwargs and 'not_seq_decode' not in kwargs:
bs = kwargs["input_ids"].shape[0]
start = time.time()
res = self.generate(
input_ids = kwargs["input_ids"],
attention_mask = kwargs["attention_mask"],
use_cache=True,
eos_token_id=self.tokenizer.eos_token_id,
min_new_tokens=10,
)
txt_ids = self.tokenizer.convert_ids_to_tokens(res[0])
pad = res.new_full(list(res.shape[:-1])+[max(self.config.max_length - res.shape[-1],0)],0)
res = torch.cat([res,pad],dim=-1)
res = res.view([bs,-1] + list(res.shape[1:]))
return {"loss":pad.new_full([1],0.0,dtype=torch.float32),"ids":res}
else:
if 'data_gt' in kwargs:
kwargs.pop('data_gt')
if 'data_src' in kwargs:
kwargs.pop('data_src')
if 'data_kd' in kwargs:
kwargs.pop('data_kd')
if 'not_seq_decode' in kwargs:
kwargs.pop('not_seq_decode')
res = self.mft_forward(**kwargs)
return res
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def mft_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
label_mask: torch.LongTensor = None,
mask_for: Optional[str] = None,
types: Optional = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# bs = None
mask_labels, masked_pos_shift, masked_pos_non_shift = None, None, None
bs = input_ids.shape[0]
if labels is not None:
mask_labels, masked_pos_shift, masked_pos_non_shift = self.mask_policy.get_gpt_masked_input(input_ids, label_mask, self.tokenizer, mask_for=mask_for, types=types)
ori_input_ids = input_ids.clone()
labels = input_ids.clone()
input_ids = mask_labels.clone()
if self.config.drop_attention_mask:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
attention_mask &= ~is_mask
if self.config.neftune_alpha != 0:
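            # NEFTune-style noise: uniform noise scaled by neftune_alpha / sqrt(length * hidden_dim),
            # added to every token embedding (neft_all_token) or only at the masked positions.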
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
input_mask = attention_mask
bs, seq_len = input_ids.shape
device = input_ids.device
mp = torch.zeros((bs, seq_len)).to(device).scatter_(1, masked_pos_non_shift.to(device),
torch.ones((bs, seq_len)).to(device))
input_lengths = torch.sum(input_mask, 1) # B
# input_lengths = torch.sum(is_mask, 1) # Bpu
noise_ = torch.zeros_like(embeds_init).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * embeds_init.size(-1)
mag = self.config.neftune_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
delta = delta.to(embeds_init)
if self.config.neft_all_token:
inputs_embeds = delta + embeds_init
else:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
delta = delta * is_mask.unsqueeze(-1)
inputs_embeds = delta + embeds_init
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif self.config.token_dropout != 0:
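            # Embedding dropout on the input; with only_drop_target the dropped embeddings are
            # kept only at masked positions and the original embeddings are used elsewhere.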
dropout = nn.Dropout(self.config.token_dropout)
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
inputs_embeds = dropout(embeds_init)
if self.config.only_drop_target:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
inputs_embeds = inputs_embeds * is_mask.unsqueeze(-1) + embeds_init * (~is_mask.unsqueeze(-1))
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif self.config.token_noise:
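            # Token-noise variants: the embeddings at masked positions are replaced by a soft
            # mixture over the vocabulary, the argmax token, a sampled token (with temperature),
            # or a uniformly random token, depending on the token_noise_* flags.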
if not self.config.token_noise_random:
with torch.no_grad():
pre_outputs = self.model(
input_ids=ori_input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pre_hidden_states = pre_outputs[0]
pre_logits = self.lm_head(pre_hidden_states)
pre_logits = pre_logits.float()
if self.config.token_noise_hard:
if self.config.token_noise_random:
vocab_size = int(self.tokenizer.vocab_size)
batch_size, seq_length = input_ids.shape
random_indices = torch.randint(vocab_size, (batch_size, seq_length), device=input_ids.device)
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
ran_input_ids = input_ids * (~is_mask) + random_indices * is_mask
result = self.model.embed_tokens.weight[random_indices]
elif self.config.token_noise_sample:
batch_size, seq_length = input_ids.shape
pre_logits = pre_logits / self.config.temperature
probabilities = torch.softmax(pre_logits, dim=-1)
probabilities = probabilities.reshape(batch_size*seq_length, -1)
sampled_indices = torch.multinomial(probabilities, 1).squeeze(-1)
sampled_indices = sampled_indices.reshape(batch_size, seq_length)
result = self.model.embed_tokens.weight[sampled_indices]
else:
_, max_indices = torch.max(pre_logits, dim=-1)
result = self.model.embed_tokens.weight[max_indices]
else:
result = torch.matmul(torch.softmax(pre_logits, -1), self.model.embed_tokens.weight)
result = result[:,1:,:]
z = torch.zeros(result.shape[0], 1, result.shape[-1], device=result.device)
result = torch.cat([result, z], dim=1)
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
delta = result
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
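            # mixup: blend 30% of the derived embedding with 70% of the original embedding at
            # masked positions; otherwise the masked-position embeddings are fully replaced.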
if self.config.mixup:
delta = delta * 0.3 * is_mask.unsqueeze(-1)
embeds_init = 0.7 * embeds_init * (is_mask.unsqueeze(-1)) + embeds_init * (~is_mask.unsqueeze(-1))
inputs_embeds = delta + embeds_init
else:
delta = delta * 1 * is_mask.unsqueeze(-1)
embeds_init = embeds_init * (~is_mask.unsqueeze(-1))# + 0 * embeds_init * (is_mask.unsqueeze(-1))
inputs_embeds = delta + embeds_init
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
else:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
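        # Shifted LM loss restricted by label_mask: positions where label_mask == 0 are
        # mapped to -100 and ignored by CrossEntropyLoss.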
if labels is not None:
shift_logits = logits[..., :-1, :] #16*1023
label_mask = label_mask[..., 1:]
shift_logits = shift_logits.contiguous()
shift_labels = labels[..., 1:]
shift_labels = shift_labels.contiguous()
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
label_mask = label_mask.contiguous().view(-1)
shift_labels_ignore = shift_labels * label_mask + (1-label_mask) * -100
loss = criterion(shift_logits, shift_labels_ignore) # 16*1023
if labels is None:
labels_ = None
logits_ = logits
else:
labels_ = labels[..., 1:]
logits_ = logits[..., :-1, :]
res = self.mask_policy.split_gpt_return(logits_, input_ids, labels_, masked_pos_shift, masked_pos_non_shift, bs,
return_dict, outputs, mask_labels, self.tokenizer, loss)
return res
# return CausalLMOutputWithPast(
# loss=loss,
# logits=logits,
# past_key_values=outputs.past_key_values,
# hidden_states=outputs.hidden_states,
# attentions=outputs.attentions,
# )
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past

View File

@ -0,0 +1,407 @@
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mistral model."""
import inspect
import math
import time
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
# from ...activations import ACT2FN
# from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
# from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
# from ...modeling_utils import PreTrainedModel
# from ...utils import (
# add_start_docstrings,
# add_start_docstrings_to_model_forward,
# is_flash_attn_2_available,
# logging,
# replace_return_docstrings,
# )
# from .configuration_mistral import MistralConfig
from transformers.models.mistral.modeling_mistral import *
from models.mask_policy_utils import MaskPolicy
from config.decorator import replace
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
@replace(MistralForCausalLM)
class MistralForCausalLM(MistralForCausalLM):
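    # Mistral counterpart of the Llama model above: the same masked-fine-tuning forward
    # logic, built on MistralModel and the Mistral docstrings.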
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = MistralModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.config = config
try:
self.mask_input = config.mask_input
self.mask_rate = config.mask_rate
except:
self.mask_input = False
self.mask_rate = 0
self.vocab_size = config.vocab_size
self.mask_policy = MaskPolicy(config=config, bpe_prefix='')
    def forward(self, **kwargs):
if not self.training and "past_key_values" not in kwargs and 'not_seq_decode' not in kwargs:
bs = kwargs["input_ids"].shape[0]
start = time.time()
res = self.generate(
input_ids = kwargs["input_ids"],
attention_mask = kwargs["attention_mask"],
use_cache=True,
eos_token_id=self.tokenizer.eos_token_id,
min_new_tokens=10,
)
txt_ids = self.tokenizer.convert_ids_to_tokens(res[0])
pad = res.new_full(list(res.shape[:-1])+[max(self.config.max_length - res.shape[-1],0)],0)
res = torch.cat([res,pad],dim=-1)
res = res.view([bs,-1] + list(res.shape[1:]))
return {"loss":pad.new_full([1],0.0,dtype=torch.float32),"ids":res}
else:
if 'data_gt' in kwargs:
kwargs.pop('data_gt')
if 'data_src' in kwargs:
kwargs.pop('data_src')
if 'data_kd' in kwargs:
kwargs.pop('data_kd')
if 'not_seq_decode' in kwargs:
kwargs.pop('not_seq_decode')
res = self.mft_forward(**kwargs)
return res
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def mft_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
label_mask: torch.LongTensor = None,
mask_for: Optional[str] = None,
types: Optional = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
>>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
'''
'''
mask_labels, masked_pos_shift, masked_pos_non_shift = None, None, None
bs = input_ids.shape[0]
if labels is not None:
mask_labels, masked_pos_shift, masked_pos_non_shift = self.mask_policy.get_gpt_masked_input(input_ids, label_mask, self.tokenizer, mask_for=mask_for, types=types)
ori_input_ids = input_ids.clone()
labels = input_ids.clone()
input_ids = mask_labels.clone()
'''
'''
if self.config.drop_mask:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
attention_mask &= ~is_mask
if self.config.neftune_alpha != 0:
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
input_mask = attention_mask
bs, seq_len = input_ids.shape
device = input_ids.device
mp = torch.zeros((bs, seq_len)).to(device).scatter_(1, masked_pos_non_shift.to(device),
torch.ones((bs, seq_len)).to(device))
# is_mask = mp.data.eq(1)
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(embeds_init).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * embeds_init.size(-1)
mag = self.config.neftune_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
delta = delta.to(embeds_init)
if self.config.neft_all_token:
inputs_embeds = delta + embeds_init
else:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
delta = delta * is_mask.unsqueeze(-1)
inputs_embeds = delta + embeds_init
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif self.config.token_dropout != 0:
dropout = nn.Dropout(self.config.token_dropout)
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
inputs_embeds = dropout(embeds_init)
if self.config.only_drop_target:
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
inputs_embeds = inputs_embeds * is_mask.unsqueeze(-1) + embeds_init * (~is_mask.unsqueeze(-1))
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif self.config.token_noise:
if not self.config.token_noise_random:
with torch.no_grad():
pre_outputs = self.model(
input_ids=ori_input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pre_hidden_states = pre_outputs[0]
pre_logits = self.lm_head(pre_hidden_states)
pre_logits = pre_logits.float()
if self.config.token_noise_hard:
if self.config.token_noise_random:
vocab_size = int(self.tokenizer.vocab_size)
batch_size, seq_length = input_ids.shape
random_indices = torch.randint(vocab_size, (batch_size, seq_length),
device=input_ids.device)
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
ran_input_ids = input_ids * (~is_mask) + random_indices * is_mask
result = self.model.embed_tokens.weight[random_indices]
elif self.config.token_noise_sample:
batch_size, seq_length = input_ids.shape
pre_logits = pre_logits / self.config.temperature
probabilities = torch.softmax(pre_logits, dim=-1)
probabilities = probabilities.reshape(batch_size * seq_length, -1)
sampled_indices = torch.multinomial(probabilities, 1).squeeze(-1)
sampled_indices = sampled_indices.reshape(batch_size, seq_length)
result = self.model.embed_tokens.weight[sampled_indices]
else:
_, max_indices = torch.max(pre_logits, dim=-1)
result = self.model.embed_tokens.weight[max_indices]
else:
result = torch.matmul(torch.softmax(pre_logits, -1), self.model.embed_tokens.weight)
result = result[:, 1:, :]
z = torch.zeros(result.shape[0], 1, result.shape[-1], device=result.device)
result = torch.cat([result, z], dim=1)
embeds_init = self.model.embed_tokens.forward(ori_input_ids)
delta = result
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
if self.config.mixup:
delta = delta * 0.3 * is_mask.unsqueeze(-1)
embeds_init = 0.7 * embeds_init * (is_mask.unsqueeze(-1)) + embeds_init * (~is_mask.unsqueeze(-1))
inputs_embeds = delta + embeds_init
else:
delta = delta * 1 * is_mask.unsqueeze(-1)
embeds_init = embeds_init * (~is_mask.unsqueeze(-1))
inputs_embeds = delta + embeds_init
outputs = self.model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
else:
if self.config.cmp:
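                # cmp: random replacement ids are drawn for the masked positions, but the
                # forward call below still consumes input_ids (the random ids are unused here).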
vocab_size = int(self.tokenizer.vocab_size)
batch_size, seq_length = input_ids.shape
random_indices = torch.randint(vocab_size, (batch_size, seq_length), device=input_ids.device)
is_mask = input_ids.data.eq(self.mask_policy.mask_id)
ran_input_ids = input_ids * (~is_mask) + random_indices * is_mask
result = self.model.embed_tokens.weight[random_indices]
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :] # 16*1023
label_mask = label_mask[..., 1:]
shift_logits = shift_logits.contiguous()
shift_labels = labels[..., 1:]
shift_labels = shift_labels.contiguous()
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
label_mask = label_mask.contiguous().view(-1)
shift_labels_ignore = shift_labels * label_mask + (1 - label_mask) * -100
loss = criterion(shift_logits, shift_labels_ignore) # 16*1023
if labels is None:
labels_ = None
logits_ = logits
else:
labels_ = labels[..., 1:]
logits_ = logits[..., :-1, :]
res = self.mask_policy.split_gpt_return(logits_, input_ids, labels_, masked_pos_shift, masked_pos_non_shift, bs,
return_dict, outputs, mask_labels, self.tokenizer, loss)
return res
# if not return_dict:
# output = (logits,) + outputs[1:]
# return (loss,) + output if loss is not None else output
#
# return CausalLMOutputWithPast(
# loss=loss,
# logits=logits,
# past_key_values=outputs.past_key_values,
# hidden_states=outputs.hidden_states,
# attentions=outputs.attentions,
# )
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past

BIN
MaskedThought/overview.png Normal file
Binary file not shown (image, 255 KiB)

View File

@ -0,0 +1,26 @@
setuptools
mkl
datasets==2.0.0
pathos
nltk
pytest
scikit-learn
sentencepiece
transformers==4.36.2
huggingface_hub
accelerate
protobuf
azureml-core
azureml-sdk
matplotlib
pandas
chardet
access_dict_by_dot
boto3
psutil
pyemd
spacy
tensorboard
peft
bitsandbytes
vllm

View File

@ -0,0 +1,19 @@
import importlib
import os
registered_outputter = dict()
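# Decorator-based registry: every .py module in this package is imported below so that
# classes decorated with @register_outputter are collected and re-exported as globals.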
def register_outputter(name):
def wrapper(outputter_cls):
registered_outputter[name] = outputter_cls
return outputter_cls
return wrapper
outputter_dir = os.path.dirname(__file__)
for f in os.listdir(outputter_dir):
fpath = os.path.join(outputter_dir,f)
if not f.startswith('.') and (f.endswith('.py')):
fname = f[:f.find('.py')]
module = importlib.import_module(f'.{fname}','trainer.Outputter')
for key,cls in registered_outputter.items():
globals()[key] = cls

View File

@ -0,0 +1,281 @@
import torch
import os
import numpy as np
from scipy.special import softmax
import transformers
from . import register_outputter
from transformers import AutoTokenizer
from pathos.helpers import mp
from typing import List, Union, Tuple
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
class DummyOutQueue:
def __init__(self, output_f):
self.output_f = output_f
def put(self, line):
self.output_f.write(line)
class MPOutputter:
"""
    Multi-process outputter for general generation tasks. Worker flow:
    `pred_value` (token ids) is put into `in_queue`
    ==> post-process workers fetch from `in_queue` and decode
    ==> the writer worker fetches from `out_queue` and writes the results out
"""
def __init__(self,args,model_name=None, **kwargs):
self.output_dir = args.output_dir
self.result_file = args.result_file
self.result_header = list(filter(lambda x: len(x) > 0,args.result_header.split(",")))
self.kept_header = set(self.result_header)
tokenizer = kwargs.pop("tokenizer", None)
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(model_name)
if args.world_size > 1:
self.of =open(os.path.join(self.output_dir,self.result_file+"_"+str(args.local_rank)),mode='w',encoding="utf-8")
else:
self.of =open(os.path.join(self.output_dir,self.result_file),mode='w',encoding="utf-8")
print("\033[0;37;41m\t{}\033[0m".format(self.of))
if args.outputter_num_workers is None:
self.workers = args.dataloader_num_workers
else:
self.workers = args.outputter_num_workers
if self.workers == 0: # disable multi-processing
self.in_queue = None
self.out_queue = DummyOutQueue(self.of)
else:
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
self.end_process_cnt = 0
def realtime_output(self,batch,pred):
if isinstance(pred,dict):
pred_value = {key:pred[key].cpu() for key in pred if pred[key] != None}
elif isinstance(pred,tuple):
pred_value = [p.cpu() for p in pred if p!=None]
else:
pred_value = pred.cpu()
return self.return_res(batch,pred_value)
def return_res(self,batch,pred):
        raise NotImplementedError
def __call__(self,batch,pred):
if batch == None:
self.in_queue.put((None,None))
return
batch_value = {key:batch[key] for key in self.kept_header}
if isinstance(pred,dict):
pred_value = {key:pred[key].cpu() for key in pred if pred[key] != None}
elif isinstance(pred,tuple):
pred_value = [p.cpu() for p in pred]
else:
pred_value = pred.cpu()
if self.workers == 0:
self.out_batch(batch_value,pred_value)
else:
self.in_queue.put((batch_value,pred_value))
def writer_process(self):
#print(len(self.out_queue))
while True:
line = self.out_queue.get()
if line == None:
self.end_process_cnt += 1
if self.end_process_cnt == self.workers:
print("Writer recieve end notice!")
break
else:
self.of.write(line)
self.of.flush()
def post_process(self):
while True:
batch,pred = self.in_queue.get()
if batch == None:
print("Recieve end notice!")
self.out_queue.put(None)
break
self.out_batch(batch,pred)
def start_process(self):
        if self.workers == 0:
            return
self.proc = []
for i in range(0,self.workers):
self.proc.append(mp.Process(target=self.post_process, args=()))
self.proc[i].start()
self.writer_proc = mp.Process(target=self.writer_process,args=())
self.writer_proc.start()
def close(self):
if self.workers > 0:
for i in range(0,self.workers):
self(None,None)
for p in self.proc:
p.join()
self.writer_proc.join()
self.of.close()
class BasicOutputter:
def __init__(self,args,model_name=None):
self.output_dir = args.output_dir
self.result_file = args.result_file
self.result_header = args.result_header.split(",")
self.of = open(os.path.join(self.output_dir,self.result_file),mode='w',encoding='utf-8')
def output(self,batch,pred):
self.out_batch(batch,pred)
def close(self):
self.of.close()
@register_outputter("generation")
class Seq2SeqOutputter(MPOutputter):
def __init__(self, args, model_name, **kwargs):
super().__init__(args, model_name, **kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.workers:
self.start_process()
def return_res(self,batch,res):
assert 1==0
rewrite = res[1] # 0 is fake loss
res = []
print(len(rewrite))
for i in range(0,len(rewrite)):
print(len(rewrite[i]))
for j in range(0,len(rewrite[i])):
newline = [self.tokenizer.decode(rewrite[i][j],skip_special_tokens=True,clean_up_tokenization_spaces=False),str(j)]
res.append(newline)
return res
def out_batch(self,batch,res):
rewrite = res[1] # 0 is fake loss
for i in range(0,len(rewrite)):
line = []
for key in self.result_header:
line.append(batch[key][i])
for j in range(0, len(rewrite[i])):
#output_tokens = self.tokenizer.convert_ids_to_tokens(rewrite[i][j])
#print(output_tokens)
res_score = [self.tokenizer.decode(rewrite[i][j],skip_special_tokens=True,clean_up_tokenization_spaces=False),str(j)]
#print(res_score)
newline = "\t".join(line + res_score)+"\n"
self.out_queue.put(newline)
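# Reuses the class name Seq2SeqOutputter; the two outputters are distinguished by their
# registry keys ("generation" vs "llama_generation"), not by the class name itself.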
@register_outputter("llama_generation")
class Seq2SeqOutputter(MPOutputter):
def __init__(self, args, model_name, **kwargs):
super().__init__(args, model_name, **kwargs)
self.tokenizer = kwargs['tokenizer']
self.args = args
if self.workers:
self.start_process()
def return_res(self, batch, res):
assert 1 == 0
rewrite = res[1] # 0 is fake loss
res = []
print(len(rewrite))
for i in range(0, len(rewrite)):
print(len(rewrite[i]))
for j in range(0, len(rewrite[i])):
newline = [
self.tokenizer.decode(rewrite[i][j], skip_special_tokens=True, clean_up_tokenization_spaces=False),
str(j)]
res.append(newline)
return res
def out_batch(self, batch, res):
rewrite = res[1] # 0 is fake loss
for i in range(0, len(rewrite)):
line = []
for key in self.result_header:
line.append(batch[key][i])
for j in range(0, len(rewrite[i])):
# output_tokens = self.tokenizer.convert_ids_to_tokens(rewrite[i][j])
# print(output_tokens)
def hh_process(line, src_line):
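                    # Rebuild the Alpaca-style prompt from the source conversation, strip it from
                    # the decoded output, cut at the next Human:/Assistant: turn, and clean up
                    # special tokens and newlines.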
# print('line1---------------------',line)
# print('src_line', src_line.replace('<#SEP#>', '').replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>",
# "\n\nAssistant: ").strip())
input_text = src_line.split('<#SEP#>')
input_text = " ".join(input_text)
input_text = input_text.replace("<|prompter|>", "\n\nHuman: ").replace("<|assistant|>",
"\n\nAssistant: ")
input_text = PROMPT_DICT['prompt_input'].format(
instruction='Give a response as the assistant with the input conversation history',
input=input_text)
self.tokenizer.truncation_side = "left"
input_text = self.tokenizer.encode(input_text,
return_tensors='np', max_length=self.args.tok_max_length - 128, truncation=True)
input_text = self.tokenizer.decode(input_text[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
print('line1---------------------\n', line)
print('src_line---------------------\n',input_text)
# line = line.replace(input_text, ' ').strip()
# print('line2',line)
# print('line3', line[len(input_text):])
# assert 1==0
# p = line.rfind('Assistant: ')
# output_line = line[p + len('Assistant: '):]
output_line = line[len(input_text):]
print('output1--------------------\n', output_line)
p = output_line.find('Human: ')
if p != -1:
output_line = output_line[:p]
p = output_line.find('Assistant: ')
if p != -1:
output_line = output_line[:p]
output_line = output_line.strip()
if len(output_line) == 0:
output_line = 'none'
output_line= output_line. \
replace('<pad>', ' '). \
replace('[PAD]', ' '). \
replace('<unk>', ' '). \
replace('<s>', ' '). \
replace('</s>', ' '). \
replace('\n', '<#LF#>'). \
strip()
print('output-------------------------\n', output_line)
return output_line
self.tokenizer.truncation_side = "right"
res_score = [
hh_process(self.tokenizer.decode(rewrite[i][j], skip_special_tokens=True, clean_up_tokenization_spaces=False), line[i]),
str(j)]
newline = "\t".join(line + res_score) + "\n"
self.out_queue.put(newline)

View File

@ -0,0 +1,19 @@
import importlib
import os
registered_trainer = dict()
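# Decorator-based registry for trainers: importing every .py module in this package fills
# registered_trainer and re-exports the registered entries as globals.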
def register_trainer(name):
def wrapper(trainer_fn):
registered_trainer[name] = trainer_fn
return trainer_fn
return wrapper
trainer_dir = os.path.dirname(__file__)
for f in os.listdir(trainer_dir):
fpath = os.path.join(trainer_dir,f)
if not f.startswith('.') and (f.endswith('.py')):
fname = f[:f.find('.py')]
module = importlib.import_module(f'.{fname}','trainer.Trainer')
for key,fn in registered_trainer.items():
globals()[key] = fn

View File

@ -0,0 +1,964 @@
import os
import shutil
import sys
from dataclasses import dataclass
import numpy as np
from numpy.lib.function_base import average, median
import torch
from torch.distributed.distributed_c10d import init_process_group
from transformers import Trainer as BaseTrainer
# from trainer.Trainer.trainer import Trainer as BaseTrainer
# from trainer.Trainer.trainer import MetricsCalculator, AmlLogger
from data.data_reader import CustomizeCollator
import trainer.Metrics as Metrics
import trainer.Outputter as Outputter
from transformers.modeling_utils import unwrap_model
from transformers import logging
try:
from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint
except:
pass
# try:
# from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
# except:
# assert 1==0
# def is_deepspeed_available():
# return False
from packaging import version
from transformers.trainer_callback import PrinterCallback,TrainerCallback, TrainerState
from config.decorator import replace
from torch.utils.data.dataloader import DataLoader
import copy
from torch.utils.data.dataset import Dataset, IterableDataset
from transformers.file_utils import is_datasets_available
from transformers.trainer_pt_utils import IterableDatasetShard, LabelSmoother
from transformers.trainer_utils import TrainOutput, has_length, speed_metrics
import math
import nltk
import torch.nn.functional as F
import random
try:
from peft import PeftModelForSeq2SeqLM, PeftModelForCausalLM
except:
pass
import time
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
ShardSampler,
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
nested_truncate,
nested_xla_mesh_reduce,
reissue_pt_warnings,
get_model_param_count,
)
from data.tokenizer_utils import *
from torch import nn
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import math
import torch.distributed as dist
from transformers import get_linear_schedule_with_warmup
from trainer.Trainer import register_trainer
from collections import deque
import collections
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from torch.utils.data.distributed import DistributedSampler
if is_datasets_available():
import datasets
import logging as local_logging
logger = logging.get_logger(__name__)
logger.setLevel('INFO')
local_logging.basicConfig(format="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s",level=logging.INFO)
from transformers.integrations import AzureMLCallback
from transformers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
can_return_loss,
find_labels,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_in_notebook,
is_ipex_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_compile_available,
is_torch_neuroncore_available,
is_torch_tpu_available,
logging,
strtobool,
)
if is_accelerate_available():
from accelerate import Accelerator, skip_first_batches
from accelerate import __version__ as accelerate_version
from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin
if version.parse(accelerate_version) > version.parse("0.20.3"):
from accelerate.utils import (
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
class HisStatisticInfo:
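    # Rolling statistics over the most recent `print_every` steps: CE/total loss, n-gram
    # accuracy, loss and entropy, timings, and the current mask rate, for periodic logging.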
def __init__(self, args):
self.args = args
self.his_dif_cont_rate = []
self.his_ce_loss = []
self.his_loss = []
self.his_2_gram_loss = []
self.his_3_gram_loss = []
self.his_4_gram_loss = []
self.his_2_gram_acc = []
self.his_3_gram_acc = []
self.his_4_gram_acc = []
self.his_1_gram_ent = []
self.his_2_gram_ent = []
self.his_3_gram_ent = []
self.his_4_gram_ent = []
self.his_static_r = []
self.his_forward_time = []
self.his_backward_time = []
self.mask_rate = 0
self.print_every = args.print_every
def update(self):
self.his_ce_loss = self.his_ce_loss[-self.print_every:]
self.his_loss = self.his_loss[-self.print_every:]
self.his_2_gram_acc = self.his_2_gram_acc[-self.print_every:]
self.his_3_gram_acc = self.his_3_gram_acc[-self.print_every:]
self.his_4_gram_acc = self.his_4_gram_acc[-self.print_every:]
self.his_2_gram_loss = self.his_2_gram_loss[-self.print_every:]
self.his_3_gram_loss = self.his_3_gram_loss[-self.print_every:]
self.his_4_gram_loss = self.his_4_gram_loss[-self.print_every:]
self.his_1_gram_ent = self.his_1_gram_ent[-self.print_every:]
self.his_2_gram_ent = self.his_2_gram_ent[-self.print_every:]
self.his_3_gram_ent = self.his_3_gram_ent[-self.print_every:]
self.his_4_gram_ent = self.his_4_gram_ent[-self.print_every:]
self.his_dif_cont_rate = self.his_dif_cont_rate[-self.print_every:]
self.his_static_r = self.his_static_r[-self.print_every:]
def print_info(self, tokenizer, step):
logger.info('At step {}, his_ce_loss {}'.format(step, np.mean(self.his_ce_loss)))
logger.info('At step {}, his_loss {}'.format(step, np.mean(self.his_loss)))
        logger.info('At step {}, his_dif_cont_rate {}'.format(step, np.mean(self.his_dif_cont_rate)))
logger.info('At step {}, gram_acc: 2:{}, 3:{}, 4:{}'.format(step, np.mean(self.his_2_gram_acc),
np.mean(self.his_3_gram_acc),
np.mean(self.his_4_gram_acc)))
logger.info('At step {}, gram_loss: 2:{}, 3:{}, 4:{}'.format(step, np.mean(self.his_2_gram_loss),
np.mean(self.his_3_gram_loss),
np.mean(self.his_4_gram_loss)))
logger.info('At step {}, gram_ent: 1: {}, 2:{}, 3:{}, 4:{}'.format(step,np.mean(self.his_1_gram_ent), np.mean(self.his_2_gram_ent),
np.mean(self.his_3_gram_ent),
np.mean(self.his_4_gram_ent)))
logger.info('At step {}, his_forward_time {}'.format(step, np.mean(self.his_forward_time)))
logger.info('At step {}, his_backward_time {}'.format(step, np.mean(self.his_backward_time)))
logger.info('At step {}, mask_rate {}'.format(step, np.mean(self.mask_rate)))
logger.info('At step {}, lr {}'.format(step, self.lr))
@register_trainer("trainer436")
class Trainer(BaseTrainer):
def __init__(self, model, args, model_args, task_args, train_dataset, eval_dataset, auto_tokenizer):
data_collator = CustomizeCollator(train_dataset, eval_dataset)
metrics_calculator = None
super().__init__(
model=model,
args=args,
train_dataset=train_dataset.get_dataset() if train_dataset else None,
eval_dataset=eval_dataset.get_dataset() if eval_dataset else None,
data_collator=data_collator,
compute_metrics=metrics_calculator
)
self.args = args
self.tokenizer = auto_tokenizer
auto_tokenizer.truncation_side = "right"
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
if isinstance(model.module, PeftModelForCausalLM) or isinstance(model.module, PeftModelForSeq2SeqLM):
self.model.module.base_model.model.tokenizer = auto_tokenizer
else:
self.model.module.tokenizer = auto_tokenizer
else:
try:
if isinstance(model, PeftModelForCausalLM) or isinstance(model, PeftModelForSeq2SeqLM):
self.model.base_model.model.tokenizer = auto_tokenizer
else:
self.model.tokenizer = auto_tokenizer
except:
self.model.tokenizer = auto_tokenizer
self.his_info = HisStatisticInfo(args)
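    # Adapted from transformers.Trainer._inner_training_loop (transformers 4.36.x), with extra
    # per-step statistics logging via HisStatisticInfo and custom end-of-epoch checkpointing
    # controlled by args.not_save / args.zero3.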
def _inner_training_loop(
self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
len_dataloader = None
num_train_tokens = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = self.num_examples(train_dataloader)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
# the best we can do.
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = (
self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
)
else:
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
if args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
else:
raise ValueError(
"args.max_steps must be set to a positive value if dataloader does not have a length, was"
f" {args.max_steps}"
)
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
if self.args.n_gpu > 1:
# nn.DataParallel(model) replicates the model, creating new variables and module
# references registered here no longer work on other gpus, breaking the module
raise ValueError(
"Currently --debug underflow_overflow is not supported under DP. Please use DDP"
" (torch.distributed.launch)."
)
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
self.lr_scheduler = None
self._created_lr_scheduler = False
if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
if args.gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
else:
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model = self._wrap_model(self.model_wrapped)
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False
if delay_optimizer_creation:
if use_accelerator_prepare:
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
if use_accelerator_prepare:
self.model.train()
if hasattr(self.lr_scheduler, "step"):
if self.use_apex:
model = self.accelerator.prepare(self.model)
else:
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# ckpt loading
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
if self.args.per_device_train_batch_size != self._train_batch_size:
logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
self.max_steps = max_steps
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
# Update the references
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
if self.hp_name is not None and self._trial is not None:
# use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
# parameter to Train when using DDP.
self.state.trial_name = self.hp_name(self._trial)
if trial is not None:
assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
self.state.trial_params = hp_params(assignments)
else:
self.state.trial_params = None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
sampler = get_dataloader_sampler(train_dataloader)
sampler_kinds = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
sampler_kinds.append(SeedableRandomSampler)
is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
break
else:
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
sampler = sampler if sampler is not None else []
_ = list(sampler)
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_iterator = train_dataloader
if hasattr(epoch_iterator, "set_epoch"):
epoch_iterator.set_epoch(epoch)
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
len(epoch_iterator)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs)
self.his_info.his_ce_loss.append(tr_loss_step.float().clone().detach().cpu().numpy() * args.gradient_accumulation_steps)
self.his_info.his_loss.append(tr_loss_step.float().clone().detach().cpu().numpy() * args.gradient_accumulation_steps)
print_every = args.print_every
output_dir = args.output_dir
# if not args.not_save:
# if args.zero3:
# print_every = args.print_every
# save_every = args.save_every
# output_dir = args.output_dir
#
# if (step + 1) % save_every == 0:
# if args.local_rank <= 0:
# if not os.path.exists(output_dir + '/' + args.exp_name):
# os.makedirs(output_dir + '/' + args.exp_name)
# save_path = output_dir + '/' + args.exp_name + '/_model.{}_{}'.format(self.state.global_step,
# step)
# # logger.info('save to {}, epoch: {}'.format(save_path, self.state.epoch))
# logger.info('save to {}, epoch: {}'.format(save_path + '_adapter', self.state.epoch))
# self.accelerator.wait_for_everyone()
# if self.accelerator.is_main_process:
# self.tokenizer.save_pretrained(save_path + '_adapter')
# unwrapped_model = self.accelerator.unwrap_model(model)
# state_dict = self.accelerator.get_state_dict(model)
# unwrapped_model.save_pretrained(
# save_path + '_adapter', is_main_process=self.accelerator.is_main_process,
# save_function=self.accelerator.save, state_dict=state_dict
# )
# if args.local_rank <= 0:
# if (step + 1) % print_every == 0:
# self.his_info.lr = self.optimizer.param_groups[0]['lr']
# self.his_info.update()
# self.his_info.print_info(tokenizer=self.tokenizer, step=step)
# else:
#
# if (step + 1) % save_every == 0:
# if args.local_rank <= 0:
# if not os.path.exists(output_dir + '/' + args.exp_name):
# os.makedirs(output_dir + '/' + args.exp_name)
# save_path = output_dir + '/' + args.exp_name + '/_model.{}_{}'.format(self.state.global_step,
# step)
# # logger.info('save to {}, epoch: {}'.format(save_path, self.state.epoch))
# logger.info('save to {}, epoch: {}'.format(save_path + '_adapter', self.state.epoch))
#
# def safe_save_model_for_hf_trainer(trainer, output_dir):
# """Collects the state dict and dump to disk."""
# state_dict = trainer.model.state_dict()
#
# cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
# del state_dict
# trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
#
# safe_save_model_for_hf_trainer(trainer=self, output_dir=save_path + '_adapter')
if args.local_rank <= 0:
# unwrap_model(model).save_pretrained(save_path + '_adapter')
# self.tokenizer.save_pretrained(save_path + '_adapter')
if (step + 1) % print_every == 0:
self.his_info.lr = self.optimizer.param_groups[0]['lr']
self.his_info.update()
self.his_info.print_info(tokenizer=self.tokenizer, step=step)
if (
args.logging_nan_inf_filter
and not is_torch_tpu_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
tr_loss += tr_loss_step
self.current_flos += float(self.floating_point_ops(inputs))
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)
if (
total_batched_samples % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc or (
version.parse(accelerate_version) <= version.parse("0.20.3")
):
self.accelerator.gradient_state._set_sync_gradients(True)
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping
if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
else:
self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
# Optimizer step
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if step < 0:
logger.warning(
"There seems to be not a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
# if args.no_zero:
# if args.local_rank <= 0:
# if not os.path.exists(output_dir + '/' + args.exp_name):
# os.makedirs(output_dir + '/' + args.exp_name)
# save_path = output_dir + '/' + args.exp_name + '/_model.{}_{}'.format(self.state.global_step, step)
# # logger.info('save to {}, epoch: {}'.format(save_path, self.state.epoch))
# logger.info('save to {}, epoch: {}'.format(save_path + '_epoch' + str(epoch), self.state.epoch))
# unwrap_model(model).save_pretrained(save_path + '_epoch' + str(epoch))
# self.tokenizer.save_pretrained(save_path + '_epoch' + str(epoch))
if not args.not_save:
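                # End-of-epoch checkpointing: under ZeRO-3 the full state dict is gathered through
                # accelerate before saving; otherwise the state dict is copied to CPU and saved
                # via Trainer._save.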
if args.zero3:
output_dir = args.output_dir
if args.local_rank <= 0:
if not os.path.exists(output_dir + '/' + args.exp_name):
os.makedirs(output_dir + '/' + args.exp_name)
save_path = output_dir + '/' + args.exp_name + '/_model.{}_{}'.format(self.state.global_step, step)
# logger.info('save to {}, epoch: {}'.format(save_path, self.state.epoch))
logger.info('save to {}, epoch: {}'.format(save_path + '_epoch' + str(epoch), self.state.epoch))
self.accelerator.wait_for_everyone()
if self.accelerator.is_main_process:
self.tokenizer.save_pretrained(save_path + '_epoch' + str(epoch))
unwrapped_model = self.accelerator.unwrap_model(model)
state_dict = self.accelerator.get_state_dict(model)
unwrapped_model.save_pretrained(
save_path + '_epoch' + str(epoch), is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save, state_dict=state_dict
)
else:
output_dir = args.output_dir
if args.local_rank <= 0:
if not os.path.exists(output_dir + '/' + args.exp_name):
os.makedirs(output_dir + '/' + args.exp_name)
save_path = output_dir + '/' + args.exp_name + '/_model.{}_{}'.format(self.state.global_step,
step)
def safe_save_model_for_hf_trainer(trainer, output_dir):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
safe_save_model_for_hf_trainer(trainer=self, output_dir=save_path + '_epoch' + str(epoch))
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
else:
logger.warning(
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
if self.control.should_training_stop:
break
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available():
xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()
self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
metrics = speed_metrics(
"train",
start_time,
num_samples=num_train_samples,
num_steps=self.state.max_steps,
num_tokens=num_train_tokens,
)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_loss
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
# Wait for the checkpoint to be uploaded.
self._finish_current_push()
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
self._deactivate_neftune(self.model)
return TrainOutput(self.state.global_step, train_loss, metrics)
def compute_acc_and_loss(self, max_ids, labels_, mp, log_probs_all):
'''
Calculate accuracy and loss while building the n-gram positional masks (positions that follow runs of masked tokens).
'''
tot_cont = 0
dif_cont = 0
# accuracy: greedy ids (max_ids) vs. labels
# "loss" below: negative log-prob of the greedy token, gathered from log_probs_all
# log_probs_all: (batch, seq_len, vocab)
masked_probs = log_probs_all.gather(2, max_ids.unsqueeze(2)).squeeze()
pred_acc = max_ids == labels_
batch_p_numpy = mp.clone().detach().cpu().numpy()
batch_2_gram_pos = []
batch_3_gram_pos = []
batch_4_gram_pos = []
batch_n_gram_pos = []
labels_np = labels_.cpu().clone().numpy()
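# mp holds the masked token positions of each sample; a run of k consecutive masks marks the token right after the run as a (k+1)-gram prediction site (2-gram after one mask, 3-gram after two, ...). batch_n_gram_pos flags any position preceded by at least one mask.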
for k, p_n in enumerate(batch_p_numpy):
cont_mask_num = 0
_2_gram_pos = [0] * len(labels_np[k])
_3_gram_pos = [0] * len(labels_np[k])
_4_gram_pos = [0] * len(labels_np[k])
_n_gram_pos = [0] * len(labels_np[k])
for i in range(0, len(p_n)):
if p_n[i] == 0:
break
if i > 0 and p_n[i] == p_n[i - 1] + 1:
cont_mask_num += 1
elif i == 0: # 0 or not cont from last pos
cont_mask_num = 1
else:
cont_mask_num = 1
if p_n[i] + 1 < len(labels_np[k]) and labels_np[k][p_n[i] + 1] != -100:
if cont_mask_num >= 1:
_n_gram_pos[p_n[i] + 1] = 1
if cont_mask_num == 1:
_2_gram_pos[p_n[i] + 1] = 1
if cont_mask_num == 2:
_3_gram_pos[p_n[i] + 1] = 1
if cont_mask_num == 3:
_4_gram_pos[p_n[i] + 1] = 1
batch_2_gram_pos.append(_2_gram_pos)
batch_3_gram_pos.append(_3_gram_pos)
batch_4_gram_pos.append(_4_gram_pos)
batch_n_gram_pos.append(_n_gram_pos)
batch_2_gram_pos = torch.tensor(batch_2_gram_pos).long().cuda()
batch_3_gram_pos = torch.tensor(batch_3_gram_pos).long().cuda()
batch_4_gram_pos = torch.tensor(batch_4_gram_pos).long().cuda()
batch_n_gram_pos = torch.tensor(batch_n_gram_pos).long().cuda()
# print(batch_2_gram_pos.shape)
# print(log_probs_all.shape)
probs = torch.exp(log_probs_all)
entropy = -torch.sum(probs * log_probs_all, dim=-1) #bs, seq
# assert 1==0
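# Per-bucket statistics: predictive entropy, negative log-prob of the greedy token and greedy accuracy, each averaged over the positions in that bucket; uni_gram_pos is the complement of batch_n_gram_pos.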
uni_gram_pos = 1 - batch_n_gram_pos
_1_gram_entropy = (entropy * uni_gram_pos).sum() / (uni_gram_pos.sum())
_2_gram_entropy = (entropy * batch_2_gram_pos).sum() / (batch_2_gram_pos.sum())
_3_gram_entropy = (entropy * batch_3_gram_pos).sum() / (batch_3_gram_pos.sum())
_4_gram_entropy = (entropy * batch_4_gram_pos).sum() / (batch_4_gram_pos.sum())
_2_gram_loss = -(masked_probs * batch_2_gram_pos).sum() / (batch_2_gram_pos.sum())
_3_gram_loss = -(masked_probs * batch_3_gram_pos).sum() / (batch_3_gram_pos.sum())
_4_gram_loss = -(masked_probs * batch_4_gram_pos).sum() / (batch_4_gram_pos.sum())
_2_gram_acc = (pred_acc * batch_2_gram_pos).sum() / (batch_2_gram_pos.sum())
_3_gram_acc = (pred_acc * batch_3_gram_pos).sum() / (batch_3_gram_pos.sum())
_4_gram_acc = (pred_acc * batch_4_gram_pos).sum() / (batch_4_gram_pos.sum())
if uni_gram_pos.sum() != 0:
self.his_info.his_1_gram_ent.append(_1_gram_entropy.cpu())
if batch_2_gram_pos.sum() != 0:
self.his_info.his_2_gram_acc.append(_2_gram_acc.cpu())
self.his_info.his_2_gram_loss.append(_2_gram_loss.cpu())
self.his_info.his_2_gram_ent.append(_2_gram_entropy.cpu())
if batch_3_gram_pos.sum() != 0:
self.his_info.his_3_gram_acc.append(_3_gram_acc.cpu())
self.his_info.his_3_gram_loss.append(_3_gram_loss.cpu())
self.his_info.his_3_gram_ent.append(_3_gram_entropy.cpu())
if batch_4_gram_pos.sum() != 0:
self.his_info.his_4_gram_acc.append(_4_gram_acc.cpu())
self.his_info.his_4_gram_loss.append(_4_gram_loss.cpu())
self.his_info.his_4_gram_ent.append(_4_gram_entropy.cpu())
return batch_2_gram_pos, batch_3_gram_pos, batch_4_gram_pos, batch_n_gram_pos
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if self.args.update_mask_rate:
try:
model.mask_policy.config.mask_rate = 0 + self.args.max_mask_rate * float(self.state.global_step) / ((self.max_steps * self.args.mask_rate_warmup_ratio) + 1)
model.mask_policy.config.mask_rate = min(model.mask_policy.config.mask_rate, self.args.max_mask_rate)
self.his_info.mask_rate = model.mask_policy.config.mask_rate
except AttributeError:
model.module.mask_policy.config.mask_rate = 0 + self.args.max_mask_rate * float(self.state.global_step) / (
(self.max_steps * self.args.mask_rate_warmup_ratio) + 1)
model.module.mask_policy.config.mask_rate = min(model.module.mask_policy.config.mask_rate, self.args.max_mask_rate)
self.his_info.mask_rate = model.module.mask_policy.config.mask_rate
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
device = inputs['input_ids'].device
ori_inputs = copy.deepcopy(inputs)
model.train()
batch_size = inputs['input_ids'].shape[0]
if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
model.module.tokenizer = self.tokenizer
model.tokenizer = self.tokenizer
else:
model.tokenizer = self.tokenizer
gt_data = inputs['data_gt']
ce_cands = gt_data
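# get_id_cands: format and pad (source, candidate-target) pairs into input_ids / labels / attention_mask / label_mask through the model's mask policy.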
def get_id_cands(cands):
inputs_local = {}
inputs_local['input_ids'], inputs_local['labels'], inputs_local['attention_mask'], \
inputs_local['label_mask'] = self.model.mask_policy.format_and_padding(inputs['data_src'], cands, device,
self.args.tok_max_length, self.args.pad_front, format=self.args.instruct_format,
tgt_max_len=self.args.tgt_max_length,
tokenizer=self.tokenizer,
instruct_type=self.args.instruct_type, cut_src=self.args.cut_src)
return inputs_local
start = time.time()
inputs_ce = get_id_cands(ce_cands)
model_return_dict = model(**inputs_ce)
outputs, _, _, max_ids, _, _, labels_, _, log_probs_all, \
_, _, masked_pos_non_shift, = model_return_dict[:12]
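# The masked-policy forward returns an extended tuple; the outputs, greedy ids, labels, token log-probs and masked positions are unpacked here for per-n-gram bookkeeping.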
self.compute_acc_and_loss(
max_ids, labels_, masked_pos_non_shift, log_probs_all)
forward_time = time.time()-start
self.his_info.his_forward_time.append(forward_time)
start = time.time()
if labels is not None:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss

View File

@ -0,0 +1,42 @@
#!/bin/bash
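# Masked fine-tuning (MFT) of Llama-2-7B on GSM8K: 10 epochs on 8 GPUs (single node), mask rate warmed up linearly to 0.4 over the first ~66% of steps.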
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export exp_name='my_llama2_7b_gsm8k_mft'
export base_model_name=Llama-2-7b-hf
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env main_llama_830.py \
--do_train \
--scene llama_generation \
--report_to none \
--seed 1 \
--trainer trainer436 \
--learning_rate 1e-5 \
--num_train_epochs 10 \
--warmup_ratio 0.01 \
--print_every 200 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 1 \
--exp_name ${exp_name} \
--train_dir data/gsm8k_train.json \
--eval_dir data/gsm8k_train.json \
--gradient_checkpointing False \
--tok_max_length 512 \
--tgt_max_length 512 \
--pad_front False \
--lr_scheduler_type "cosine" \
--model ${base_model_name} \
--model_name ${base_model_name} \
--instruct_format True \
--bf16 True \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--output_dir '.' \
--max_mask_rate 0.4 \
--update_mask_rate True \
--mask_rate_warmup 0.66 \
--save_steps 117 \
--replace True \
--replace_rate 1

View File

@ -0,0 +1,40 @@
#!/bin/bash
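# Masked fine-tuning (MFT) of Llama-2-7B on the GSM8K RFT-100 data (data/rft_100_2.json): 5 epochs on 8 GPUs, mask rate warmed up to 0.4.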
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export exp_name='my_llama2_7b_gsm8k_rft100_mft'
export base_model_name=Llama-2-7b-hf
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env main_llama_830.py \
--do_train \
--scene llama_generation \
--report_to none \
--seed 1 \
--trainer trainer436 \
--learning_rate 1e-5 \
--num_train_epochs 5 \
--warmup_ratio 0.02 \
--print_every 200 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 1 \
--exp_name ${exp_name} \
--train_dir data/rft_100_2.json \
--eval_dir data/rft_100_2.json \
--gradient_checkpointing False \
--tok_max_length 512 \
--tgt_max_length 512 \
--pad_front False \
--lr_scheduler_type "cosine" \
--model ${base_model_name} \
--model_name ${base_model_name} \
--instruct_format True \
--bf16 True \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--output_dir '.' \
--max_mask_rate 0.4 \
--update_mask_rate True \
--mask_rate_warmup 0.66 \
--save_steps 734

View File

@ -0,0 +1,40 @@
#!/bin/bash
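# Masked fine-tuning (MFT) of Llama-2-7B on the GSM8K RFT-U33 data (data/rft_u33.json): 5 epochs on 8 GPUs, mask rate warmed up to 0.4, larger save interval.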
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export exp_name='my_llama2_7b_gsm8k_rftu33_mft'
export base_model_name=Llama-2-7b-hf
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env main_llama_830.py \
--do_train \
--scene llama_generation \
--report_to none \
--seed 1 \
--trainer trainer436 \
--learning_rate 1e-5 \
--num_train_epochs 5 \
--warmup_ratio 0.02 \
--print_every 200 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 1 \
--exp_name ${exp_name} \
--train_dir data/rft_u33.json \
--eval_dir data/rft_u33.json \
--gradient_checkpointing False \
--tok_max_length 512 \
--tgt_max_length 512 \
--pad_front False \
--lr_scheduler_type "cosine" \
--model ${base_model_name} \
--model_name ${base_model_name} \
--instruct_format True \
--bf16 True \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--output_dir '.' \
--max_mask_rate 0.4 \
--update_mask_rate True \
--mask_rate_warmup 0.66 \
--save_steps 1718

View File

@ -0,0 +1,40 @@
#!/bin/bash
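# Standard supervised fine-tuning (SFT) baseline of Llama-2-7B on GSM8K: 3 epochs on 8 GPUs with masking disabled (max_mask_rate 0, update_mask_rate False).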
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export exp_name='my_llama2_7b_gsm8k_sft'
export base_model_name=Llama-2-7b-hf
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env main_llama_830.py \
--do_train \
--scene llama_generation \
--report_to none \
--seed 1 \
--trainer trainer436 \
--learning_rate 1e-5 \
--num_train_epochs 3 \
--warmup_ratio 0.03 \
--print_every 200 \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 1 \
--exp_name ${exp_name} \
--train_dir data/gsm8k_train.json \
--eval_dir data/gsm8k_train.json \
--gradient_checkpointing False \
--tok_max_length 512 \
--tgt_max_length 512 \
--pad_front False \
--lr_scheduler_type "cosine" \
--model ${base_model_name} \
--model_name ${base_model_name} \
--instruct_format True \
--bf16 True \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--output_dir '.' \
--max_mask_rate 0 \
--update_mask_rate False \
--mask_rate_warmup 0.66 \
--save_steps 117

View File

@ -0,0 +1,41 @@
#!/bin/bash
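# Masked fine-tuning (MFT) of Llama-2-7B on MATH (data/math.json): 10 epochs on 8 GPUs, longer prompts (tok_max_length 800, cut_src) and a lower peak mask rate of 0.2.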
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export exp_name='my_llama2_7b_math_mft'
export base_model_name=Llama-2-7b-hf
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env main_llama_830.py \
--do_train \
--scene llama_generation \
--report_to none \
--seed 1 \
--trainer trainer436 \
--learning_rate 1e-5 \
--num_train_epochs 10 \
--warmup_ratio 0.01 \
--print_every 200 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 1 \
--exp_name ${exp_name} \
--train_dir data/math.json \
--eval_dir data/math.json \
--gradient_checkpointing False \
--tok_max_length 800 \
--tgt_max_length 512 \
--cut_src True \
--pad_front False \
--lr_scheduler_type "cosine" \
--model ${base_model_name} \
--model_name ${base_model_name} \
--instruct_format True \
--bf16 True \
--tf32 True \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--output_dir '.' \
--max_mask_rate 0.2 \
--update_mask_rate True \
--mask_rate_warmup 0.66 \
--save_steps 117