slowllama: split model prepare

parent bb47b24ef6
commit a35b78e3b7
.gitignore
@@ -3,4 +3,6 @@ data/**
 *.log
 frozen/**
 inputs/**
 model/**
+llama7b_serve/**
+out/**

finetune.py
@@ -30,11 +30,15 @@ gen_tokens = 32
 log_lora_grad = False
 log_lora_weight = True
 
 #model_path = '../llama-2-7b'
-model_path = 'model'
+model_path = 'llama7b_serve'
+snapshots_path = 'out'
 finetune_file = './README.md'
 prompt = 'slowllama is a '
 
+if not os.path.exists(snapshots_path):
+    os.makedirs(snapshots_path)
+
 # data to finetune on
 with open(finetune_file) as f:
     text = f.read()
@@ -68,7 +72,7 @@ if __name__ == '__main__':
 
     logging.info(f'loaded dataset: {len(tokens)} tokens')
 
-    model = load_frozen(model_path, compute_dtype=compute_dtype, offload_location=offload_to, served_model_path=served_model_path).to(device).to(compute_dtype)
+    model = load_frozen(model_path, compute_dtype=compute_dtype).to(device).to(compute_dtype)
 
     def get_batch(batch_size):
         index = torch.randint(len(tokens) - seq_len, (batch_size,))
@@ -100,4 +104,4 @@ if __name__ == '__main__':
         elif loss < last_loss:
             last_loss = loss
             logging.info(f'saving snapshot')
-            torch.save(model.state_dict(), f'data/state_dict_{i}.pth')
+            torch.save(model.state_dict(), os.path.join(snapshots_path, f'state_dict_{i}.pth'))
loader.py
@@ -123,8 +123,9 @@ def load_frozen(path, **kwargs):
 
     args = ModelArgs(**config)
     args.init_frozen = False
+    args.served_model_path = path
     model = Transformer(args)
-    model.load_state_dict(torch.load(os.path.join(args.served_model_path, 'model.pth')), strict=False)
+    model.load_state_dict(torch.load(os.path.join(path, 'model.pth')), strict=False)
     return model
 
 def add_lora(model_path, lora_path):
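The effect of this loader change: the folder passed to load_frozen is now itself the served-model location, so callers no longer thread served_model_path through kwargs. A minimal sketch of the new call shape, using the folder name from this commit (dtype as in finetune.py above):

import torch
from loader import load_frozen

# the path argument now doubles as the served-model folder; no separate
# served_model_path kwarg is needed
model = load_frozen('llama7b_serve', compute_dtype=torch.float32)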
(new file)
@@ -0,0 +1,20 @@
+# loads model in original llama2 format and saves to another folder in sequential format
+
+import torch
+import logging
+
+from loader import load_llama2
+
+seed = 54321
+device = 'mps' # mps for macbooks
+offload_to = 'disk'
+
+compute_dtype = torch.float32
+
+llama2_model_path = '../llama-2-7b'
+served_model_path = 'llama7b_serve'
+
+logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='prepare_model.log')
+torch.random.manual_seed(seed)
+
+model = load_llama2(llama2_model_path, compute_dtype=compute_dtype, offload_location=offload_to, served_model_path=served_model_path).to(device).to(compute_dtype)
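This new script performs the one-time conversion that finetune.py no longer does on every run; judging by its log filename it would be saved as prepare_model.py, though the name is not shown in this diff. A minimal sketch of the same step invoked directly, with load_llama2's arguments taken from the script above:

import torch
from loader import load_llama2

# one-time conversion: read the original llama2 checkpoint and write it out
# in sequential format under llama7b_serve/, offloading weights to disk
model = load_llama2('../llama-2-7b', compute_dtype=torch.float32,
                    offload_location='disk', served_model_path='llama7b_serve')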
@@ -5,7 +5,7 @@ import os
 sys.path.insert(0, '../llama/llama')
 from tokenizer import Tokenizer
 
-from loader import load_llama2
+from loader import load_frozen
 
 import logging
 
@@ -13,15 +13,11 @@ logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
 
 model_path = sys.argv[1]
 device = sys.argv[2] if len(sys.argv) > 2 else 'cpu'
-lora_weights = sys.argv[3] if len(sys.argv) > 3 else None
 
 tokenizer_path = os.path.join(model_path, 'tokenizer.model')
 tokenizer = Tokenizer(tokenizer_path)
 
-model = load_llama2(sys.argv[1], dropout=0.0).to(device)
-
-if lora_weights is not None:
-    print(model.load_state_dict(torch.load(lora_weights), strict=False))
+model = load_frozen(sys.argv[1], dropout=0.0).to(device)
 
 def greedy_gen(prompt, max_new_tokens=50):
     tokens = torch.tensor(tokenizer.encode(prompt, True, False)).view(1, -1).to(device)
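With the LoRA branch removed, the generation script reduces to a tokenizer plus a frozen model loaded from the served folder. A sketch of the equivalent setup, with the argv values spelled out as constants (paths are the ones used elsewhere in this commit):

import os
import sys

sys.path.insert(0, '../llama/llama')
from tokenizer import Tokenizer
from loader import load_frozen

model_path = 'llama7b_serve'  # sys.argv[1] in the script
device = 'mps'                # sys.argv[2], defaults to 'cpu'

# tokenizer.model lives inside the served-model folder
tokenizer = Tokenizer(os.path.join(model_path, 'tokenizer.model'))
model = load_frozen(model_path, dropout=0.0).to(device)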