# slowllama/loader.py
#
# 203 lines
# 7.5 KiB
# Python

import torch
import os
import json
import gc
import glob
import logging
import shutil
from model_config import ModelArgs
from blackbox_model import Transformer
from utils import device_supports_dtype
# How llama2 weights are sharded across checkpoint files: for each weight
# name, the axis along which the per-shard pieces are joined
# (0 = by rows, 1 = by columns). Names that are absent from this table are
# treated as unsharded by get_subset / get_w_subset (the whole tensor is used).
join_dim = dict(
    wq=0,
    wk=0,
    wv=0,
    wo=1,
    w1=0,
    w2=1,
    w3=0,
    output=0,
    tok_embeddings=1,
)
def get_subset(title, weight_subset, index):
    """Return the index tuple at which shard `index` sits in the full weight.

    Args:
        title: weight name, looked up in `join_dim`.
        weight_subset: the piece held by one shard; only its `.shape` is read.
        index: zero-based shard number.

    Returns:
        A tuple of slices. For names listed in `join_dim` it is a 2-tuple
        whose slice along the join axis spans
        `[size * index, size * (index + 1))`, where `size` is the extent of
        `weight_subset` along that axis; the other axis is taken whole.
        For any other name every axis of `weight_subset` is taken whole.
    """
    if title not in join_dim:
        # Unsharded weight: each shard carries the complete tensor.
        return (slice(None),) * len(weight_subset.shape)

    axis = join_dim[title]
    size = weight_subset.shape[axis]
    window = slice(size * index, size * (index + 1))
    everything = slice(None)
    # Place the window on the join axis, leave the other axis untouched.
    return (everything, window) if axis == 1 else (window, everything)
def get_w_subset(title, weight, shards, shard):
    """Return the index tuple selecting shard `shard`'s piece of a full weight.

    Unlike `get_subset`, `weight` here is the complete (unsharded) tensor, so
    the piece size is its extent along the join axis floor-divided by
    `shards`.

    Args:
        title: weight name, looked up in `join_dim`.
        weight: the full tensor; only its `.shape` is read.
        shards: total number of shards.
        shard: zero-based shard number.

    Returns:
        A tuple of slices: a 2-tuple with the shard's window on the join axis
        for names listed in `join_dim`, otherwise one whole-axis slice per
        dimension of `weight`.
    """
    if title not in join_dim:
        # Unsharded weight: take the tensor as a whole.
        return (slice(None),) * len(weight.shape)

    axis = join_dim[title]
    piece = weight.shape[axis] // shards
    window = slice(piece * shard, piece * (shard + 1))
    everything = slice(None)
    if axis == 1:
        return (everything, window)
    return (window, everything)
def apply_subset(module, weight_subset, checkpoint_index, title):
    """Write one shard's piece of a weight into `module.weight` in place.

    Args:
        module: object exposing a `.weight` tensor to be filled.
        weight_subset: the tensor piece read from one checkpoint shard.
        checkpoint_index: zero-based index of that shard.
        title: weight name used by `get_subset` to pick the target region
            (names it does not know select the whole tensor).
    """
    target_region = get_subset(title, weight_subset, checkpoint_index)
    # The assignment mutates a parameter in place, so keep autograd out of it.
    with torch.no_grad():
        module.weight[target_region] = weight_subset
def _copy_shard_weights(root, checkpoint, shard_index, key_prefix,
                        norm_name=None, norm_key=None):
    """Copy one shard's tensors into every weight-bearing submodule of `root`.

    Args:
        root: module whose `named_modules()` are walked.
        checkpoint: dict loaded from one shard file. Every entry that is
            consumed is deleted from it.
        shard_index: zero-based index of the shard, forwarded to
            `apply_subset` to position the piece.
        key_prefix: prefix of the checkpoint key; the key is
            `f'{key_prefix}{title}.weight'`.
        norm_name: optional; submodules whose title contains this string are
            read from `norm_key` instead of the prefixed key.
        norm_key: checkpoint key used for the `norm_name` case.
    """
    for title, submodule in root.named_modules():
        if not hasattr(submodule, 'weight'):
            continue
        if norm_name is not None and norm_name in title:
            full_path = norm_key
        else:
            full_path = f'{key_prefix}{title}.weight'
        apply_subset(submodule, checkpoint[full_path], shard_index, title)
        # Drop the source tensor right after it has been copied.
        del checkpoint[full_path]
        gc.collect()


def prepare_model(llama2_path, frozen_path, **kwargs):
    """Build a Transformer from sharded llama2 checkpoints and save it.

    Reads `params.json` and every `consolidated.*.pth` shard found in
    `llama2_path`, copies the shard pieces into a freshly created model, then
    writes `params.json`, `tokenizer.model` and `model.pth` into
    `frozen_path`.

    Args:
        llama2_path: directory holding params.json, tokenizer.model and the
            consolidated.*.pth shards.
        frozen_path: output directory; also stored on the model args as
            `served_model_path`.
        **kwargs: overrides applied on top of the values from params.json
            before `ModelArgs` is constructed.

    Returns:
        The populated model, converted to `args.frozen_dtype`.

    NOTE(review): `loaded_inner()` / `save()` come from blackbox_model and are
    not visible here; presumably they fetch a wrapped module for editing and
    persist it again - confirm against that module.
    """
    params_path = os.path.join(llama2_path, 'params.json')
    with open(params_path, 'r') as conf_file:
        config = json.load(conf_file)
    config.update(kwargs)

    args = ModelArgs(**config)
    args.vocab_size = args.vocab_size_override
    args.served_model_path = frozen_path

    logging.info('creating model instance')
    model = Transformer(args)

    paths = sorted(glob.glob(f'{llama2_path}/consolidated.*.pth'))
    shards = len(paths)
    if not paths:
        # Without shards nothing below fills the model; make that visible
        # instead of silently saving whatever Transformer() initialised.
        logging.warning(f'prepare_model: no consolidated.*.pth files found in {llama2_path}')

    for ci, checkpoint_path in enumerate(paths):
        logging.info(f'prepare_model: processing checkpoint {ci} out of {shards}')
        # NOTE: torch.load unpickles the file; only use checkpoints from a
        # trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        for i, layer in enumerate(model.layers):
            layer_prefix = f'layers.{i}.'

            # Weights exposed directly by the layer itself.
            _copy_shard_weights(layer, checkpoint, ci, layer_prefix)

            attention = layer.attention.loaded_inner()
            _copy_shard_weights(
                attention, checkpoint, ci, f'layers.{i}.attention.',
                norm_name='attention_norm',
                norm_key=f'layers.{i}.attention_norm.weight',
            )
            layer.attention.save(attention)

            feed_forward = layer.feed_forward.loaded_inner()
            _copy_shard_weights(
                feed_forward, checkpoint, ci, f'layers.{i}.feed_forward.',
                norm_name='ffn_norm',
                norm_key=f'layers.{i}.ffn_norm.weight',
            )
            layer.feed_forward.save(feed_forward)

            logging.info(f'prepare_model: updating layer {i} out of {len(model.layers)}')

        # now repeat for other submodules: output, embeddings and norm
        title = 'output'
        block = model.output.loaded_inner()
        apply_subset(block, checkpoint[f'{title}.weight'], ci, title)
        logging.info('prepare_model: updating output layer')
        model.output.save(block)

        title = 'tok_embeddings'
        block = model.tok_embeddings.loaded_inner()
        apply_subset(block, checkpoint[f'{title}.weight'], ci, title)
        logging.info('prepare_model: updating token embeddings')
        model.tok_embeddings.save(block)

        # norm left; title None is not in join_dim, so it is copied whole
        apply_subset(model.norm, checkpoint['norm.weight'], ci, None)

        # Release this shard before the next one is loaded so two shards are
        # never held in memory at once (same pattern as add_lora).
        del checkpoint
        gc.collect()

    # we also need to copy:
    # - params.json
    # - model dict itself (norm + Lora)
    # - tokenizer
    shutil.copy(params_path, os.path.join(frozen_path, 'params.json'))
    shutil.copy(os.path.join(llama2_path, 'tokenizer.model'), os.path.join(frozen_path, 'tokenizer.model'))
    torch.save(model.to(args.frozen_dtype).state_dict(), os.path.join(frozen_path, 'model.pth'))

    return model
def load_frozen(path, **kwargs):
    """Create a Transformer from a directory written by `prepare_model`.

    Args:
        path: directory containing params.json and model.pth; also stored on
            the model args as `served_model_path`.
        **kwargs: overrides applied on top of the values from params.json
            before `ModelArgs` is constructed.

    Returns:
        The model, converted to `args.compute_dtype`, with the entries of
        model.pth loaded non-strictly (keys missing on either side are
        ignored).
    """
    logging.info(f'loading sequential model from {path}')
    params_path = os.path.join(path, 'params.json')
    with open(params_path, 'r') as conf_file:
        config = json.load(conf_file)
    config.update(kwargs)

    args = ModelArgs(**config)
    args.vocab_size = args.vocab_size_override
    args.init_frozen = False
    args.served_model_path = path

    logging.info('creating model instance')
    model = Transformer(args).to(args.compute_dtype)

    logging.info('loading model dict')
    # Deserialize onto CPU like every other torch.load in this file, so the
    # load does not depend on the device the tensors were saved from;
    # load_state_dict then copies the values into the model's own parameters.
    # NOTE: torch.load unpickles the file; only use trusted model.pth files.
    state_dict = torch.load(os.path.join(path, 'model.pth'), map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model
# this is merging LoRA back to original weights in llama2 format
def add_lora(model_path, lora_path):
    """Merge LoRA deltas into the wq / wv weights of sharded llama2 files.

    For every `consolidated.*.pth` shard in `model_path` and every layer, the
    delta `B @ A * (lora_alpha / lora_rank)` is computed, the piece belonging
    to that shard is selected with `get_w_subset`, and it is added to
    `layers.{layer}.attention.wq|wv.weight`.

    The shard files are overwritten in place, so running this twice on the
    same directory adds the delta twice.

    Args:
        model_path: directory with params.json and consolidated.*.pth shards.
        lora_path: file with LoRA weights keyed as
            `{q|v}_lora_{layer}.{A|B}.weight`.

    NOTE(review): `lora_alpha` / `lora_rank` come from `ModelArgs` built from
    params.json alone; presumably they must match the values used when the
    LoRA weights were trained - confirm with the caller.
    """
    # NOTE: torch.load unpickles the file; only use trusted inputs.
    lora_weights = torch.load(lora_path, map_location='cpu')
    paths = sorted(glob.glob(f'{model_path}/consolidated.*.pth'))

    params_path = os.path.join(model_path, 'params.json')
    with open(params_path, 'r') as conf_file:
        config = json.load(conf_file)

    shards = len(paths)
    config = ModelArgs(**config)
    n_layers = int(config.n_layers)
    lora_scale = config.lora_alpha / config.lora_rank

    for ci, checkpoint_path in enumerate(paths):
        logging.info(f'add_lora: processing checkpoint {ci} out of {shards}')
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        for layer in range(n_layers):
            logging.info(f'add_lora: processing checkpoint {ci} layer {layer} out of {n_layers}')
            for attn_key in ['v', 'q']:
                local_path = f'w{attn_key}'
                checkpoint_key = f'layers.{layer}.attention.{local_path}.weight'
                lora_a = lora_weights[f'{attn_key}_lora_{layer}.A.weight']
                lora_b = lora_weights[f'{attn_key}_lora_{layer}.B.weight']

                original_type = lora_b.dtype
                if device_supports_dtype('cpu', original_type):
                    lora = lora_b.mm(lora_a) * lora_scale
                else:
                    # CPU cannot multiply in this dtype: go through float32
                    # and convert the product back.
                    lora = lora_b.to(torch.float32).mm(lora_a.to(torch.float32)) * lora_scale
                    lora = lora.to(original_type)

                subset = get_w_subset(local_path, lora, shards, ci)
                # Cast the delta to the stored weight's dtype rather than a
                # fixed bfloat16: adding bf16 to e.g. an fp16 weight would
                # promote the sum to float32 and change the dtype written
                # back to disk. For bf16 checkpoints this is unchanged.
                target_dtype = checkpoint[checkpoint_key].dtype
                checkpoint[checkpoint_key] = checkpoint[checkpoint_key] + lora[subset].to(target_dtype)

        torch.save(checkpoint, checkpoint_path)
        # Free this shard before loading the next one.
        del checkpoint
        gc.collect()