slowllama: combined update

* make lora params configurable
* better output for finetune prompt completion
* cleanup peak_rss_mb
* readme update
Oleksandr Kuvshynov 2023-09-15 09:30:47 -04:00
parent b42dc58859
commit 6b164fb6fb
6 changed files with 20 additions and 81 deletions

View File

@ -86,8 +86,8 @@ First, we need to be able to load a model which requires more RAM than we have a
Doing the forward pass is easy - we just load modules when we need them and pass the output forward.
The backward pass is a little more tricky - in a way, we have to run the forward pass twice. The way it's [currently implemented](https://github.com/okuvshynov/slowllama/blob/main/blackbox_model.py#L351) is:
1. Do a forward pass while also saving inputs to each offloaded block to the SSD. The goal of the first forward pass is to compute the final loss and cache inputs to each offloaded module.
2. Then, do a manual backward gradient propagation. We start from the last module, re-run each module (forward, to build autograd graph) with the input we cached on step (1) again. After that we run backward pass within that block only, and pass the gradient for the input to the next (previous?) module. As we use LoRA, only LoRA weights are being updated. LoRA weights are not offloaded to disk, always staying on RAM/GPU. Important: we also need to save and restore random number generation state before evaluating each offloaded module. During training we use dropout, and randomly switched off neurons should be the same on both forward passes.
1. Do a forward pass while also saving the inputs to each offloaded block to the SSD. The goal of this first forward pass is to compute the final loss and cache the inputs to each offloaded block.
2. Then, do a manual backward gradient propagation. We start from the last block and re-run each block once again (forward, to build the autograd graph) with the same input we cached in step (1). After that we run the backward pass within that block only, and pass the gradient for the input to the next (previous?) block. As we use LoRA, only LoRA gradients are saved. LoRA weights are not offloaded to disk, always staying in RAM/on the GPU. Important: we also need to save and restore the random number generation state before evaluating each offloaded block. During training we use dropout, and the randomly switched-off neurons should be the same on both forward passes (see the sketch after this excerpt).
3. After that we run an optimizer step on the LoRA weights and save them separately if needed.
Original llama2 weights are in bfloat16, but the mps backend doesn't support that type natively, so we do the computation in float32 instead.
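The loop below is a minimal, self-contained sketch of this scheme, not slowllama's actual implementation: block inputs are cached in an in-memory dict instead of on the SSD, every parameter is trainable instead of only the LoRA weights, and only the CPU RNG state is handled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Three "offloaded" blocks with dropout, plus an output head.
blocks = nn.ModuleList([nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Dropout(0.1)) for _ in range(3)])
head = nn.Linear(8, 1)
x, target = torch.randn(4, 8), torch.randn(4, 1)

# Pass 1: forward only. Cache each block's input and the RNG state right before
# running it, so the dropout masks can be reproduced later.
cache = {}
h = x
with torch.no_grad():
    for i, block in enumerate(blocks):
        cache[i] = (h.clone(), torch.random.get_rng_state())
        h = block(h)
    final_hidden = h.clone()

# The head builds its autograd graph normally; its input gradient starts the chain.
final_hidden.requires_grad_(True)
loss = F.mse_loss(head(final_hidden), target)
loss.backward()
grad = final_hidden.grad

# Pass 2: walk the blocks in reverse, re-run each one with the cached input and
# RNG state (same dropout mask), backprop through that block only, and pass the
# input gradient on to the previous block.
for i in reversed(range(len(blocks))):
    inp, rng = cache[i]
    inp.requires_grad_(True)
    torch.random.set_rng_state(rng)
    out = blocks[i](inp)      # rebuilds only this block's graph
    out.backward(grad)        # chain rule with the upstream gradient
    grad = inp.grad

print(loss.item(), grad.shape)
```

Only one block's autograd graph is alive at a time, which is what keeps peak memory bounded by roughly a single block plus the LoRA weights.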

View File

@ -6,6 +6,7 @@
# - simplify init/generation as we only use it for fine-tuning experiments
# - manual backprop
# - support for ffn_dim_multiplier which llama2-70b uses
# - LoRA
import math
from dataclasses import dataclass
@ -28,11 +29,14 @@ class ModelArgs:
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5
max_seq_len: int = 2048
dropout: float = 0.0
dropout: float = 0.0 # keep at 0.0 unless we bring back configurable dropout
ffn_dim_multiplier: Optional[float] = None
compute_dtype: torch.dtype = torch.float32
offload_location: str = 'disk' # 'disk' or 'ram'
rope_theta: float = 10000.0
lora_rank: int = 8
lora_alpha: int = 64
lora_dropout: float = 0.05
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float):
@ -221,7 +225,7 @@ class TransformerBlock(nn.Module):
return out
class LoRA(nn.Module):
def __init__(self, original_layer, rank=8, alpha=64, dropout=0.05):
def __init__(self, original_layer, rank, alpha, dropout):
super().__init__()
n, m = original_layer.weight.shape
self.A = nn.Linear(m, rank, bias=False)
@ -250,8 +254,8 @@ class Transformer(nn.Module):
self.lora_layers = []
for layer_id in range(params.n_layers):
block = TransformerBlock(layer_id, params)
q_lora = LoRA(block.attention.wq).to(params.compute_dtype)
v_lora = LoRA(block.attention.wv).to(params.compute_dtype)
q_lora = LoRA(block.attention.wq, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
v_lora = LoRA(block.attention.wv, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
self.lora_layers.append({ 'q_lora': q_lora, 'v_lora': v_lora})
self.add_module(f'q_lora_{layer_id}', q_lora)
self.add_module(f'v_lora_{layer_id}', v_lora)
@ -341,7 +345,7 @@ class Transformer(nn.Module):
norm_out = norm_out.detach()
norm_out.requires_grad = True
# TODO: micro-optimization: as output is last layer, we can skip loading it second time
# TODO: micro-optimization: as output is the last layer, we can skip loading and running it a second time
logits = self.output(norm_out)
logits = logits.detach()
logits.requires_grad = True
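For reference, a rough sketch of how such a LoRA adapter wraps a frozen linear layer, with rank/alpha/dropout taken from configuration rather than hard-coded defaults. `LoRAAdapter`, its zero-init of `B`, and the usage values are illustrative assumptions, not the file above verbatim.

```python
import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    def __init__(self, original_layer: nn.Linear, rank: int, alpha: int, dropout: float):
        super().__init__()
        n, m = original_layer.weight.shape          # (out_features, in_features)
        self.A = nn.Linear(m, rank, bias=False)     # down-projection
        self.B = nn.Linear(rank, n, bias=False)     # up-projection
        nn.init.zeros_(self.B.weight)               # delta starts at zero (assumed init)
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / rank

    def forward(self, x):
        # Low-rank delta only; it gets added to the frozen layer's output elsewhere.
        return self.B(self.A(self.dropout(x))) * self.scale

# Usage mirroring the constructor above, with hyperparameters coming from config
# (params.lora_rank / lora_alpha / lora_dropout) instead of hard-coded defaults.
wq = nn.Linear(64, 64, bias=False)
q_lora = LoRAAdapter(wq, rank=8, alpha=64, dropout=0.05)
print(q_lora(torch.randn(2, 16, 64)).shape)         # torch.Size([2, 16, 64])
```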

View File

@ -15,7 +15,6 @@ seed = 54321
iters = 1000
device = 'mps' # mps for macbooks
seq_len = 128
dropout = 0.01
batch_size = 16
lr = 1e-4
offload_to = 'disk'
@ -30,7 +29,7 @@ gen_tokens = 32
log_lora_grad = False
log_lora_weight = True
model_path = '/Volumes/LLAMAS//llama-2-70b'
model_path = '../llama-2-7b'
finetune_file = './README.md'
prompt = 'slowllama is a '
@ -41,7 +40,7 @@ with open(finetune_file) as f:
tokenizer_path = os.path.join(model_path, 'tokenizer.model')
tokenizer = Tokenizer(tokenizer_path)
def greedy_gen(prompt, max_new_tokens=50):
def greedy_gen(prompt, iter, max_new_tokens=50):
tokens = torch.tensor(tokenizer.encode(prompt, True, False)).view(1, -1).to(device)
model.eval()
for _ in range(max_new_tokens):
@ -49,11 +48,12 @@ def greedy_gen(prompt, max_new_tokens=50):
logits = logits[:, -1, :]
logits_top, next_tokens = torch.topk(logits, k=25, dim=-1)
next_token = next_tokens[0, 0].view(1, 1)
logging.info(f'next tokens: {logits_top} {next_tokens} {tokenizer.decode(next_tokens.tolist())}')
logging.info(f'next token: {next_token}')
#logging.info(f'next tokens: {logits_top} {next_tokens} {tokenizer.decode(next_tokens.tolist())}')
tokens = torch.cat((tokens, next_token), dim=1)
for i, output in enumerate(tokens):
logging.info(f'{i} - {tokenizer.decode(output.tolist())}')
for output in tokens:
logging.info(f'after {iter} iterations: {tokenizer.decode(output.tolist())}')
if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='finetune.log')
@ -66,7 +66,7 @@ if __name__ == '__main__':
logging.info(f'loaded dataset: {len(tokens)} tokens')
model = load_llama2(model_path, dropout=dropout, compute_dtype=compute_dtype, offload_location=offload_to).to(device).to(compute_dtype)
model = load_llama2(model_path, compute_dtype=compute_dtype, offload_location=offload_to).to(device).to(compute_dtype)
def get_batch(batch_size):
index = torch.randint(len(tokens) - seq_len, (batch_size,))
@ -82,7 +82,7 @@ if __name__ == '__main__':
X, y = get_batch(batch_size)
opt.zero_grad()
if i % eval_period == 0:
greedy_gen(prompt, max_new_tokens=gen_tokens)
greedy_gen(prompt, i, max_new_tokens=gen_tokens)
# both forward and backward passes are here.
# returned loss is a scalar, not variable
logits, loss = model.manual_loop(X, y)
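A toy, self-contained version of the new logging behaviour in `greedy_gen` - log the whole decoded completion tagged with the current training iteration, instead of per-step top-k dumps. `ToyLM` and the character-level `encode`/`decode` are stand-ins for the real llama model and tokenizer, not part of the repo.

```python
import logging
import torch
import torch.nn as nn

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

vocab = list(' abcdefghijklmnopqrstuvwxyz')          # stand-in char-level tokenizer
encode = lambda s: [vocab.index(c) for c in s]
decode = lambda ids: ''.join(vocab[i] for i in ids)

class ToyLM(nn.Module):                              # stand-in for the llama model
    def __init__(self, vocab_size=len(vocab), dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)
    def forward(self, tokens):
        return self.head(self.emb(tokens))           # (batch, seq, vocab)

def greedy_gen(model, prompt, iteration, max_new_tokens=8):
    tokens = torch.tensor(encode(prompt)).view(1, -1)
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(tokens)[:, -1, :]
            _, next_tokens = torch.topk(logits, k=5, dim=-1)
            next_token = next_tokens[0, 0].view(1, 1)   # greedy: take the top-1
            tokens = torch.cat((tokens, next_token), dim=1)
    for output in tokens:
        logging.info(f'after {iteration} iterations: {decode(output.tolist())}')

greedy_gen(ToyLM(), 'slowllama is a ', iteration=0)
```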

View File

@ -7,7 +7,6 @@ import glob
import logging
from blackbox_model import Transformer, ModelArgs
from utils import peak_rss_mb
vocab_size = 32000

View File

@ -1,60 +0,0 @@
import torch
import os
import resource
def device_map(device):
if str(device).startswith('mps'):
return 'mps'
return str(device)
def device_supports_dtype(device, dtype):
try:
tensor = torch.tensor([1.0, 2.0]).to(device).to(dtype)
return True
except TypeError as e:
return False
global_id_auto = 0
def next_id():
global global_id_auto
res = torch.tensor(global_id_auto)
global_id_auto += 1
return res
def intermediate_path(id):
if torch.is_tensor(id):
id = id.item()
folder = f'{os.path.dirname(__file__)}/data'
if not os.path.exists(folder):
os.makedirs(folder)
return f'{folder}/saved_{id}.pt'
def save_rng_state(device='cpu'):
if device == 'cpu':
import torch
return torch.random.get_rng_state()
elif device.startswith('cuda'):
import torch.cuda
return torch.cuda.get_rng_state(device=int(device.split(':')[1]))
elif device.startswith('mps'):
import torch.mps
return torch.mps.get_rng_state()
else:
raise ValueError(f"Unsupported device: {device}")
def restore_rng_state(rng_state, device='cpu'):
if device == 'cpu':
import torch
torch.random.set_rng_state(rng_state)
elif device.startswith('cuda'):
import torch.cuda
torch.cuda.set_rng_state(rng_state, device=int(device.split(':')[1]))
elif device.startswith('mps'):
import torch.mps
torch.mps.set_rng_state(rng_state)
else:
raise ValueError(f"Unsupported device: {device}")
def peak_rss_mb():
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // (1024 * 1024)

View File

@ -1,6 +1,5 @@
import torch
import os
import resource
def device_map(device):
if str(device).startswith('mps'):
@ -54,7 +53,4 @@ def restore_rng_state(rng_state, device='cpu'):
import torch.mps
torch.mps.set_rng_state(rng_state)
else:
raise ValueError(f"Unsupported device: {device}")
def peak_rss_mb():
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // (1024 * 1024)
raise ValueError(f"Unsupported device: {device}")