slowllama: split model prepare

Oleksandr Kuvshynov 2023-09-27 23:45:18 -04:00
parent bb47b24ef6
commit a35b78e3b7
5 changed files with 35 additions and 12 deletions

.gitignore (vendored)

@@ -3,4 +3,6 @@ data/**
 *.log
 frozen/**
 inputs/**
-model/**
+model/**
+llama7b_serve/**
+out/**

finetune.py

@@ -30,11 +30,15 @@ gen_tokens = 32
 log_lora_grad = False
 log_lora_weight = True
 
 #model_path = '../llama-2-7b'
-model_path = 'model'
+model_path = 'llama7b_serve'
+snapshots_path = 'out'
 finetune_file = './README.md'
 prompt = 'slowllama is a '
 
+if not os.path.exists(snapshots_path):
+    os.makedirs(snapshots_path)
+
 # data to finetune on
 with open(finetune_file) as f:
     text = f.read()
@@ -68,7 +72,7 @@ if __name__ == '__main__':
 
     logging.info(f'loaded dataset: {len(tokens)} tokens')
 
-    model = load_frozen(model_path, compute_dtype=compute_dtype, offload_location=offload_to, served_model_path=served_model_path).to(device).to(compute_dtype)
+    model = load_frozen(model_path, compute_dtype=compute_dtype).to(device).to(compute_dtype)
 
     def get_batch(batch_size):
         index = torch.randint(len(tokens) - seq_len, (batch_size,))
@@ -100,4 +104,4 @@ if __name__ == '__main__':
         elif loss < last_loss:
             last_loss = loss
             logging.info(f'saving snapshot')
-            torch.save(model.state_dict(), f'data/state_dict_{i}.pth')
+            torch.save(model.state_dict(), os.path.join(snapshots_path, f'state_dict_{i}.pth'))
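
With snapshots now written under snapshots_path, a run can be resumed by loading a saved state dict back into the model. A minimal sketch, assuming a hypothetical snapshot index of 100:

import os
import torch

# hypothetical restore of a snapshot written by the loop above; index 100 is illustrative
state = torch.load(os.path.join('out', 'state_dict_100.pth'), map_location='cpu')
model.load_state_dict(state, strict=False)  # strict=False tolerates missing/unexpected keys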

loader.py

@@ -123,8 +123,9 @@ def load_frozen(path, **kwargs):
 
     args = ModelArgs(**config)
     args.init_frozen = False
+    args.served_model_path = path
     model = Transformer(args)
-    model.load_state_dict(torch.load(os.path.join(args.served_model_path, 'model.pth')), strict=False)
+    model.load_state_dict(torch.load(os.path.join(path, 'model.pth')), strict=False)
     return model
 
 def add_lora(model_path, lora_path):
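
With this change load_frozen derives the serving location from its path argument, so call sites only pass the prepared-model folder. A minimal call-site sketch, using the folder name and dtype that appear elsewhere in this commit:

import torch
from loader import load_frozen

# 'llama7b_serve' is the folder produced by prepare_model.py
model = load_frozen('llama7b_serve', compute_dtype=torch.float32)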

prepare_model.py (new file)

@@ -0,0 +1,20 @@
+# loads model in original llama2 format and saves to another folder in sequential format
+
+import torch
+import logging
+
+from loader import load_llama2
+
+seed = 54321
+device = 'mps' # mps for macbooks
+offload_to = 'disk'
+compute_dtype = torch.float32
+
+llama2_model_path = '../llama-2-7b'
+served_model_path = 'llama7b_serve'
+
+logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='prepare_model.log')
+
+torch.random.manual_seed(seed)
+
+model = load_llama2(llama2_model_path, compute_dtype=compute_dtype, offload_location=offload_to, served_model_path=served_model_path).to(device).to(compute_dtype)
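
prepare_model.py is a one-time step before finetuning: running python prepare_model.py reads the original llama2 checkpoint from ../llama-2-7b and writes the sequential format into llama7b_serve/. A quick sanity check afterwards (model.pth is the filename load_frozen reads in the loader.py hunk above):

import os

# the prepared folder should now contain the serialized model
assert os.path.exists(os.path.join('llama7b_serve', 'model.pth'))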

test_gen.py

@@ -5,7 +5,7 @@ import os
 
 sys.path.insert(0, '../llama/llama')
 
 from tokenizer import Tokenizer
-from loader import load_llama2
+from loader import load_frozen
 
 import logging
@@ -13,15 +13,11 @@ logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
 
 model_path = sys.argv[1]
 device = sys.argv[2] if len(sys.argv) > 2 else 'cpu'
 lora_weights = sys.argv[3] if len(sys.argv) > 3 else None
 
 tokenizer_path = os.path.join(model_path, 'tokenizer.model')
 tokenizer = Tokenizer(tokenizer_path)
-model = load_llama2(sys.argv[1], dropout=0.0).to(device)
-
-if lora_weights is not None:
-    print(model.load_state_dict(torch.load(lora_weights), strict=False))
-
+model = load_frozen(sys.argv[1], dropout=0.0).to(device)
 
 def greedy_gen(prompt, max_new_tokens=50):
     tokens = torch.tensor(tokenizer.encode(prompt, True, False)).view(1, -1).to(device)
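
Given the argv handling above, a post-prepare invocation would look roughly like python test_gen.py llama7b_serve mps (folder and device are examples from this commit), after which generation goes through the greedy_gen defined here, e.g.:

# hypothetical call; prompt reused from finetune.py above
greedy_gen('slowllama is a ', max_new_tokens=50)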