slowllama: configs

parent 66f071b8ee
commit 018912536b
@@ -6,11 +6,6 @@ import torch
from utils import device_map, next_id, device_supports_dtype
from model_config import ModelArgs

# a wrapper around arbitrary module which can save/load inner model to hard drive
# we store base weights always as bfloat16 (that's what llama2 uses)
# but we need to load and return it as a type we use for computation.
# it gets a little more tricky for MPS device because we cannot load bfloat16 there
# directly.
class BlackboxDisk(torch.nn.Module):
    def __init__(self, module, args: ModelArgs):
        super().__init__()

@@ -19,6 +14,7 @@ class BlackboxDisk(torch.nn.Module):
        self.compute_dtype = args.compute_dtype
        self.served_model_path = args.served_model_path
        self.cached_data_path = args.cached_data_path
        # TODO: can we deduce this from the data itself
        self.frozen_dtype = args.frozen_dtype
        if args.init_frozen:
            torch.save(module.to('cpu').to(self.frozen_dtype), self.frozen_path())

@@ -43,7 +39,7 @@ class BlackboxDisk(torch.nn.Module):
            return torch.load(self.frozen_path(), map_location=device_map(device)).to(self.compute_dtype)
        else:
            res = torch.load(self.frozen_path(), map_location='cpu')
            return res.to(self.frozen_dtype).to(device_map(device))
            return res.to(self.compute_dtype).to(device_map(device))

    def save(self, module):
        torch.save(module.to('cpu').to(self.frozen_dtype), self.frozen_path())
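The hunks above belong to the BlackboxDisk offloading wrapper: base weights live on disk in frozen_dtype (bfloat16 by default) and are cast to compute_dtype when loaded back, with a detour through CPU on MPS because bfloat16 checkpoints cannot be loaded there directly. A minimal standalone sketch of that load path; the helper name load_for_compute is hypothetical, the rest follows the diff:

import torch

def load_for_compute(frozen_path, device, compute_dtype, frozen_dtype=torch.bfloat16):
    # hypothetical helper mirroring the load branch in BlackboxDisk above
    if device.startswith('mps') and frozen_dtype == torch.bfloat16:
        # MPS cannot load bfloat16 directly: stage the checkpoint on CPU, then cast and move
        res = torch.load(frozen_path, map_location='cpu')
        return res.to(compute_dtype).to(device)
    # elsewhere the checkpoint can be mapped straight onto the target device
    return torch.load(frozen_path, map_location=device).to(compute_dtype)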
@@ -0,0 +1,30 @@
import logging

offload_to = 'disk'
device = 'mps'
seed = 54321

lr = 1e-4

log_lora_grad = False
log_lora_weight = True

lora_rank = 4

log_level = logging.DEBUG

# training settings

iters = 20
seq_len = 128
batch_size = 16

eval_before_training = False
eval_period = 20
gen_tokens = 32

snapshots_path = 'out'
finetune_file = './test_data/cubestat.txt'
prompt = 'Cubestat reports the following metrics: '

llama2_model_path = '../llama-2-7b'
@@ -0,0 +1,8 @@
import torch
from conf import *

adamw_eps = 1e-4
compute_dtype = torch.float16
frozen_dtype = torch.float16

frozen_model_path = '../llama7b_f16'
@@ -0,0 +1,8 @@
import torch
from conf import *

adamw_eps = 1e-8
compute_dtype = torch.float32
frozen_dtype = torch.bfloat16

frozen_model_path = '../llama7b'
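The three new files above form a small layered config: conf.py holds the shared defaults, while conf_fp16.py and conf_fp32.py star-import it and override the dtypes, the AdamW epsilon (1e-8 underflows in float16, which is why the fp16 variant uses 1e-4, per the comment removed from finetune.py below), and the frozen model path. A minimal sketch of how the names resolve in a script that picks the fp32 variant, as finetune.py does below:

from conf_fp32 import *  # pulls in everything from conf.py, then the fp32 overrides

print(device, seed, lr, lora_rank)    # shared defaults from conf.py
print(compute_dtype, frozen_dtype)    # torch.float32 / torch.bfloat16 from conf_fp32.py
print(frozen_model_path)              # '../llama7b'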
e2e7b.sh
@@ -1,6 +1,6 @@
# end-to-end test with llama7b.
# TODO: improve to clean up last iter and make it take some params
python prepare_model.py
python test_gen.py ../llama7b mps
python test_gen.py
python finetune.py
python test_gen.py ../llama7b mps ./out/state_dict_19.pth
python test_gen.py ./out/state_dict_19.pth
finetune.py
@@ -6,37 +6,7 @@ from loader import load_frozen
from plot_lora import log_lora
from utils import Tokenizer, greedy_gen

# training settings
seed = 54321
iters = 20
device = 'mps' # mps for macbooks
seq_len = 128
batch_size = 16
lr = 1e-4
adamw_eps = 1e-4 # need to change as 1e-8 doesn't fit to float16
offload_to = 'disk'

# type used for computation. Might be different from storage type (which is bfloat16)
#compute_dtype = torch.float32 # float32 for macbooks
#compute_dtype = torch.bfloat16 # bfloat16 for CUDA
compute_dtype = torch.float16
frozen_dtype = torch.float16

eval_before_training = False
eval_period = 20
gen_tokens = 32

log_lora_grad = False
log_lora_weight = True

model_path = '../llama7b_f16'
snapshots_path = 'out'
finetune_file = './test_data/cubestat.txt'
prompt = 'Cubestat reports the following metrics: '

lora_rank = 4

log_level = logging.DEBUG
from conf_fp32 import *

if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(message)s', level=log_level, filename='logs/finetune.log')

@@ -49,12 +19,12 @@ if __name__ == '__main__':
    with open(finetune_file) as f:
        text = f.read()

    tokenizer = Tokenizer(os.path.join(model_path, 'tokenizer.model'))
    tokenizer = Tokenizer(os.path.join(frozen_model_path, 'tokenizer.model'))
    tokens = tokenizer.encode(text, True, True)

    logging.info(f'loaded dataset: {len(tokens)} tokens')

    model = load_frozen(model_path, compute_dtype=compute_dtype, lora_rank=lora_rank, frozen_dtype=frozen_dtype).to(device).to(compute_dtype)
    model = load_frozen(frozen_model_path, compute_dtype=compute_dtype, lora_rank=lora_rank, frozen_dtype=frozen_dtype).to(device).to(compute_dtype)

    def get_batch(batch_size):
        index = torch.randint(len(tokens) - seq_len, (batch_size,))
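The hunk above stops inside get_batch. For context, a plausible completion under these settings; the body below is a sketch consistent with the sampling line shown, not necessarily the exact code in finetune.py:

    def get_batch(batch_size):
        index = torch.randint(len(tokens) - seq_len, (batch_size,))
        # hypothetical continuation: (input, target) pairs shifted by one token
        x = torch.stack([torch.tensor(tokens[i:i + seq_len]) for i in index.tolist()])
        y = torch.stack([torch.tensor(tokens[i + 1:i + seq_len + 1]) for i in index.tolist()])
        return x.to(device), y.to(device)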
loader.py
@@ -9,8 +9,6 @@ import shutil
from model_config import ModelArgs
from blackbox_model import Transformer

vocab_size = 32000

# how are weights sharded in llama2 - by rows or columns
join_dim = {
    'attention.wq': 0,

@@ -51,17 +49,17 @@ def apply_subset(module, weight_subset, checkpoint_index, title):
    idx_subset = get_subset(title, weight_subset, checkpoint_index)
    module.weight[idx_subset] = weight_subset

def prepare_model(llama2_path, sequential_path, **kwargs):
def prepare_model(llama2_path, frozen_path, **kwargs):
    params_path = os.path.join(llama2_path, 'params.json')
    with open(params_path, 'r') as conf_file:
        config = json.loads(conf_file.read())

    config['vocab_size'] = vocab_size
    for k, v in kwargs.items():
        config[k] = v

    args = ModelArgs(**config)
    args.served_model_path = sequential_path
    args.vocab_size = args.vocab_size_override
    args.served_model_path = frozen_path

    logging.info('creating model instance')
    model = Transformer(args)

@@ -107,9 +105,9 @@ def prepare_model(llama2_path, sequential_path, **kwargs):
    # - params.json
    # - model dict itself (norm + Lora)
    # - tokenizer?'
    shutil.copy(params_path, os.path.join(sequential_path, 'params.json'))
    shutil.copy(os.path.join(llama2_path, 'tokenizer.model'), os.path.join(sequential_path, 'tokenizer.model'))
    torch.save(model.to(args.frozen_dtype).state_dict(), os.path.join(sequential_path, 'model.pth'))
    shutil.copy(params_path, os.path.join(frozen_path, 'params.json'))
    shutil.copy(os.path.join(llama2_path, 'tokenizer.model'), os.path.join(frozen_path, 'tokenizer.model'))
    torch.save(model.to(args.frozen_dtype).state_dict(), os.path.join(frozen_path, 'model.pth'))

    return model

@@ -119,11 +117,11 @@ def load_frozen(path, **kwargs):
    with open(params_path, 'r') as conf_file:
        config = json.loads(conf_file.read())

    config['vocab_size'] = vocab_size
    for k, v in kwargs.items():
        config[k] = v

    args = ModelArgs(**config)
    args.vocab_size = args.vocab_size_override
    args.init_frozen = False
    args.served_model_path = path
    logging.info(f'creating model instance')
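The module-level vocab_size = 32000 constant is gone from loader.py: llama-2's params.json ships vocab_size = -1, and instead of patching the config dict the loader now fixes it up from the new ModelArgs.vocab_size_override field (see the model_config.py hunk below). A standalone illustration of that path, assuming the stock llama-2-7b params.json:

import json, os
from model_config import ModelArgs

llama2_path = '../llama-2-7b'                      # path taken from conf.py
with open(os.path.join(llama2_path, 'params.json')) as conf_file:
    config = json.loads(conf_file.read())          # contains vocab_size = -1

args = ModelArgs(**config)
args.vocab_size = args.vocab_size_override         # 32000 by default after this commit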
@@ -9,7 +9,6 @@ class ModelArgs:
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5
    max_seq_len: int = 2048

@@ -24,4 +23,6 @@ class ModelArgs:
    served_model_path: str = '' # relative path by default
    cached_data_path: str = '' # relative path by default
    init_frozen: bool = True
    frozen_dtype: torch.dtype = torch.bfloat16
    vocab_size: int = 32000
    vocab_size_override: int = 32000
@@ -4,18 +4,11 @@ import torch
import logging

from loader import prepare_model
from conf_fp32 import *

seed = 54321
device = 'mps' # mps for macbooks
offload_to = 'disk'
lora_rank = 4

compute_dtype = torch.float32

llama2_model_path = '../llama-2-7b'
served_model_path = '../llama7b/'

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='logs/prepare_model.log')
logging.basicConfig(format='%(asctime)s %(message)s',
                    level=logging.INFO, filename='logs/prepare_model.log')
torch.random.manual_seed(seed)

prepare_model(llama2_path=llama2_model_path, sequential_path=served_model_path, compute_dtype=compute_dtype, offload_location=offload_to, lora_rank=lora_rank).to(device).to(compute_dtype)
prepare_model(llama2_path=llama2_model_path, frozen_path=frozen_model_path, compute_dtype=compute_dtype,
              offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=frozen_dtype)
@@ -10,12 +10,10 @@ device = 'mps' # mps for macbooks
offload_to = 'disk'
lora_rank = 4

compute_dtype = torch.float16

llama2_model_path = '../llama-2-7b'
served_model_path = '../llama7b_f16/'

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)
torch.random.manual_seed(seed)

prepare_model(llama2_path=llama2_model_path, sequential_path=served_model_path, compute_dtype=compute_dtype, offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=torch.float16).to(device).to(compute_dtype)
prepare_model(llama2_path=llama2_model_path, frozen_path=served_model_path, compute_dtype=compute_dtype, offload_location=offload_to, lora_rank=lora_rank, frozen_dtype=torch.float16).to(device).to(compute_dtype)
test_gen.py
@@ -5,17 +5,16 @@ import os

from loader import load_frozen
from utils import Tokenizer, greedy_gen
from conf_fp32 import *

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.DEBUG)

model_path = sys.argv[1]
device = sys.argv[2] if len(sys.argv) > 2 else 'cpu'
lora_weights = sys.argv[3] if len(sys.argv) > 3 else None

tokenizer_path = os.path.join(model_path, 'tokenizer.model')
lora_weights = sys.argv[1] if len(sys.argv) > 1 else None

tokenizer_path = os.path.join(frozen_model_path, 'tokenizer.model')
tokenizer = Tokenizer(tokenizer_path)

model = load_frozen(sys.argv[1], dropout=0.0, lora_rank=4, frozen_dtype=torch.float16, compute_dtype=torch.float16).to(device)
model = load_frozen(frozen_model_path, dropout=0.0, lora_rank=4, frozen_dtype=frozen_dtype, compute_dtype=compute_dtype).to(device)
if lora_weights is not None:
    logging.debug(model.load_state_dict(torch.load(lora_weights), strict=False))
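test_gen.py now takes at most one optional argument, the LoRA snapshot produced by finetune.py; the model path, device, and dtypes all come from the imported config. A minimal sketch of doing the same outside the script; the snapshot filename is the one e2e7b.sh expects, everything else comes from this diff:

import torch
from conf_fp32 import *
from loader import load_frozen

model = load_frozen(frozen_model_path, dropout=0.0, lora_rank=lora_rank,
                    frozen_dtype=frozen_dtype, compute_dtype=compute_dtype).to(device)
# strict=False: the snapshot holds only the LoRA and norm weights, not the frozen base
model.load_state_dict(torch.load('./out/state_dict_19.pth'), strict=False)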