slowllama: fp16 e2e test
parent 018912536b
commit f23d7bd929
conf.py (+1 -1)

@@ -27,4 +27,4 @@ snapshots_path = 'out'
 finetune_file = './test_data/cubestat.txt'
 prompt = 'Cubestat reports the following metrics: '
 
-llama2_model_path = '../llama-2-7b'
+llama2_model_path = '../llama-2-7b'
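Every script below switches from conf_fp32 to conf_fp16, so the fp16 run is driven entirely by a config module this diff does not show. A minimal sketch of what conf_fp16.py presumably contains: the names are inferred from the call sites in prepare_model.py, finetune.py, and test_gen.py, while the concrete values are illustrative assumptions (paths and seed mirror the deleted test_fp16.py).

# conf_fp16.py -- sketch; names inferred from this commit's call sites,
# values are assumptions
import torch
import logging

seed = 54321
log_level = logging.DEBUG
device = 'mps'                  # mps for macbooks, as in the deleted test_fp16.py
offload_to = 'disk'
lora_rank = 4

compute_dtype = torch.float16   # fp16 compute is the point of this config
frozen_dtype = torch.float16    # dtype the frozen weights are stored in

llama2_model_path = '../llama-2-7b'
frozen_model_path = '../llama7b_f16'    # assumed, after test_fp16.py's served_model_path
snapshots_path = 'out'
finetune_file = './test_data/cubestat.txt'
prompt = 'Cubestat reports the following metrics: '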
e2e7b.sh (+1 -1)

@@ -3,4 +3,4 @@
 python prepare_model.py
 python test_gen.py
 python finetune.py
-python test_gen.py ./out/state_dict_19.pth
+python test_gen.py ./out/state_dict_19.pth
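This script is the whole e2e test: convert the weights, generate a baseline with the un-tuned model, finetune, then generate again from the final LoRA snapshot, which is passed as a positional argument. test_gen.py's handling of that argument is not part of this diff; given the `if lora_weights is not None:` check in its second hunk below, it presumably amounts to something like:

# assumed CLI handling in test_gen.py (not shown in this diff)
import sys

lora_weights = sys.argv[1] if len(sys.argv) > 1 else None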
|
@ -6,7 +6,7 @@ from loader import load_frozen
|
|||
from plot_lora import log_lora
|
||||
from utils import Tokenizer, greedy_gen
|
||||
|
||||
from conf_fp32 import *
|
||||
from conf_fp16 import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', level=log_level, filename='logs/finetune.log')
|
||||
|
|
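finetune.py is what produces the ./out/state_dict_19.pth consumed by e2e7b.sh. Its training loop is outside this diff; the snapshot naming implies a per-iteration save along these lines (a toy sketch assuming 20 iterations; the Linear stand-in is only there to make it runnable):

import os
import torch
import torch.nn as nn

snapshots_path = 'out'
model = nn.Linear(4, 4)        # stand-in for the LoRA-augmented llama
os.makedirs(snapshots_path, exist_ok=True)
for it in range(20):           # 20 iterations assumed from state_dict_19.pth
    # ...one optimization step on finetune_file's data would go here...
    torch.save(model.state_dict(), f'{snapshots_path}/state_dict_{it}.pth')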
|
@ -4,11 +4,11 @@ import torch
|
|||
import logging
|
||||
|
||||
from loader import prepare_model
|
||||
from conf_fp32 import *
|
||||
from conf_fp16 import *
|
||||
|
||||
logging.basicConfig(format='%(asctime)s %(message)s',
|
||||
level=logging.INFO, filename='logs/prepare_model.log')
|
||||
torch.random.manual_seed(seed)
|
||||
|
||||
prepare_model(llama2_path=llama2_model_path, frozen_path=frozen_model_path, compute_dtype=compute_dtype,
|
||||
offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=frozen_dtype)
|
||||
offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=frozen_dtype)
|
||||
|
|
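Per the comment in the deleted test_fp16.py below, prepare_model loads the model in the original llama2 format and saves it to another folder in sequential format, with the frozen weights stored in frozen_dtype. The practical effect of frozen_dtype=torch.float16 is roughly halving the offloaded footprint:

# rough frozen-weight footprint for a 7B model, by frozen_dtype
params = 7e9
print(f'fp32: {params * 4 / 2**30:.1f} GiB')   # ~26.1 GiB
print(f'fp16: {params * 2 / 2**30:.1f} GiB')   # ~13.0 GiB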
test_fp16.py (-19, deleted)

@@ -1,19 +0,0 @@
-# loads model in original llama2 format and saves to another folder in sequential format
-
-import torch
-import logging
-
-from loader import prepare_model
-
-seed = 54321
-device = 'mps' # mps for macbooks
-offload_to = 'disk'
-lora_rank = 4
-
-llama2_model_path = '../llama-2-7b'
-served_model_path = '../llama7b_f16/'
-
-logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
-torch.random.manual_seed(seed)
-
-prepare_model(llama2_path=llama2_model_path, frozen_path=served_model_path, compute_dtype=compute_dtype, offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=torch.float16).to(device).to(compute_dtype)
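The one-off conversion script is superseded by the conf_fp16-driven prepare_model.py above; note that the deleted file referenced compute_dtype without ever defining or importing it.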
test_gen.py (+3 -1)

@@ -5,7 +5,7 @@ import os
 
 from loader import load_frozen
 from utils import Tokenizer, greedy_gen
-from conf_fp32 import *
+from conf_fp16 import *
 
 logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
 
@@ -18,6 +18,8 @@ model = load_frozen(frozen_model_path, dropout=0.0, lora_rank=4, frozen_dtype=fr
 if lora_weights is not None:
     logging.debug(model.load_state_dict(torch.load(lora_weights), strict=False))
 
 logging.info('Model loaded.')
 
+prompt = 'Cubestat reports the following metrics: '
+
 greedy_gen(model, tokenizer, device, prompt, max_new_tokens=30)
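The added local prompt shadows the one star-imported from the conf module. greedy_gen itself is untouched by this commit; for reference, greedy decoding in its simplest form looks like this (a generic sketch, not slowllama's actual utils.greedy_gen):

import torch

def greedy_gen_sketch(model, tokenizer, device, prompt, max_new_tokens=30):
    # encode the prompt, then repeatedly append the highest-probability token
    tokens = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(tokens)                  # (1, seq_len, vocab)
        next_token = logits[0, -1].argmax().view(1, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokenizer.decode(tokens[0].tolist())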