slowllama: fp16 e2e test

Oleksandr Kuvshynov 2023-10-11 08:00:27 -04:00
parent 018912536b
commit f23d7bd929
6 changed files with 8 additions and 25 deletions

View File

@@ -27,4 +27,4 @@ snapshots_path = 'out'
 finetune_file = './test_data/cubestat.txt'
 prompt = 'Cubestat reports the following metrics: '
-llama2_model_path = '../llama-2-7b'
+llama2_model_path = '../llama-2-7b'
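
The scripts changed below switch their wildcard import to conf_fp16, which (together with the hunk above and the deleted helper at the bottom of this commit) implies the config also defines seed, device, offload_to, lora_rank, compute_dtype, frozen_dtype, frozen_model_path, and log_level. A minimal sketch of such a module, assuming fp16 for both compute and frozen weights; the dtype, log level, and frozen path values are illustrative assumptions, not taken from this diff:

# conf_fp16.py -- illustrative sketch only; values flagged below are assumptions
import logging
import torch

seed = 54321                    # matches the deleted helper script in this commit
log_level = logging.DEBUG       # assumption; finetune.py reads log_level
device = 'mps'                  # mps for macbooks, as in the deleted helper
offload_to = 'disk'
lora_rank = 4

compute_dtype = torch.float16   # assumption: fp16 end to end, per the commit title
frozen_dtype = torch.float16    # assumption

llama2_model_path = '../llama-2-7b'
frozen_model_path = '../llama7b_f16/'   # assumption; the deleted helper used this path as served_model_path

snapshots_path = 'out'
finetune_file = './test_data/cubestat.txt'
prompt = 'Cubestat reports the following metrics: '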

View File

@@ -3,4 +3,4 @@
 python prepare_model.py
 python test_gen.py
 python finetune.py
-python test_gen.py ./out/state_dict_19.pth
+python test_gen.py ./out/state_dict_19.pth

View File

@@ -6,7 +6,7 @@ from loader import load_frozen
 from plot_lora import log_lora
 from utils import Tokenizer, greedy_gen
-from conf_fp32 import *
+from conf_fp16 import *

 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s %(message)s', level=log_level, filename='logs/finetune.log')

View File

@@ -4,11 +4,11 @@ import torch
 import logging
 from loader import prepare_model
-from conf_fp32 import *
+from conf_fp16 import *

 logging.basicConfig(format='%(asctime)s %(message)s',
                     level=logging.INFO, filename='logs/prepare_model.log')

 torch.random.manual_seed(seed)

 prepare_model(llama2_path=llama2_model_path, frozen_path=frozen_model_path, compute_dtype=compute_dtype,
-              offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=frozen_dtype)
+              offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=frozen_dtype)

View File

@@ -1,19 +0,0 @@
-# loads model in original llama2 format and saves to another folder in sequential format
-import torch
-import logging
-from loader import prepare_model
-seed = 54321
-device = 'mps' # mps for macbooks
-offload_to = 'disk'
-lora_rank = 4
-llama2_model_path = '../llama-2-7b'
-served_model_path = '../llama7b_f16/'
-logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
-torch.random.manual_seed(seed)
-prepare_model(llama2_path=llama2_model_path, frozen_path=served_model_path, compute_dtype=compute_dtype, offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=torch.float16).to(device).to(compute_dtype)

View File

@@ -5,7 +5,7 @@ import os
 from loader import load_frozen
 from utils import Tokenizer, greedy_gen
-from conf_fp32 import *
+from conf_fp16 import *

 logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
@@ -18,6 +18,8 @@ model = load_frozen(frozen_model_path, dropout=0.0, lora_rank=4, frozen_dtype=fr
 if lora_weights is not None:
     logging.debug(model.load_state_dict(torch.load(lora_weights), strict=False))
 logging.info('Model loaded.')
+prompt = 'Cubestat reports the following metrics: '
+greedy_gen(model, tokenizer, device, prompt, max_new_tokens=30)
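
The e2e script at the top of this commit calls test_gen.py both with no argument and with ./out/state_dict_19.pth, while this hunk only shows lora_weights being consumed. A minimal sketch of how that optional positional argument could be read, assuming a plain sys.argv check; the actual argument handling in test_gen.py is not part of this diff:

import sys

# Optional positional argument: path to LoRA weights saved by finetune.py.
# With no argument, generation runs with the frozen base weights only.
lora_weights = sys.argv[1] if len(sys.argv) > 1 else None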