slowllama: combined update
* make lora params configurable
* better output for finetune prompt completion
* cleanup peak_rss_mb
* readme update
parent b42dc58859
commit 6b164fb6fb
README.md

@@ -86,8 +86,8 @@ First, we need to be able to load a model which requires more RAM than we have a
 Doing the forward pass is easy - we just load modules as we need them and pass the output forward.
 
 The backward pass is a little trickier: in a way, we have to run the forward pass twice. The way it's [currently implemented](https://github.com/okuvshynov/slowllama/blob/main/blackbox_model.py#L351) is:
 
-1. Do a forward pass while also saving inputs to each offloaded block to the SSD. The goal of the first forward pass is to compute the final loss and cache inputs to each offloaded module.
-2. Then, do a manual backward gradient propagation. We start from the last module, re-run each module (forward, to build autograd graph) with the input we cached on step (1) again. After that we run backward pass within that block only, and pass the gradient for the input to the next (previous?) module. As we use LoRA, only LoRA weights are being updated. LoRA weights are not offloaded to disk, always staying on RAM/GPU. Important: we also need to save and restore random number generation state before evaluating each offloaded module. During training we use dropout, and randomly switched off neurons should be the same on both forward passes.
+1. Do a forward pass while also saving inputs to each offloaded block to the SSD. The goal of the first forward pass is to compute the final loss and cache the inputs to each offloaded block.
+2. Then, do a manual backward gradient propagation. We start from the last block and re-run each block once again (forward, to build the autograd graph) with the same input we cached on step (1). After that we run the backward pass within that block only, and pass the gradient for the input to the next (previous?) block. As we use LoRA, only LoRA gradients are being saved; LoRA weights are not offloaded to disk, always staying in RAM/GPU memory. Important: we also need to save and restore the random number generator state before evaluating each offloaded block. During training we use dropout, and the randomly switched-off neurons should be the same on both forward passes (see the sketch right after this list).
 3. After that we run an optimizer step on the LoRA weights and save them separately if needed.
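A minimal sketch of that two-pass loop, assuming a plain Python list of `blocks` (each an `nn.Module`) and keeping the input cache in RAM rather than on SSD; `manual_loop` in blackbox_model.py is the real implementation:

```python
import torch

def manual_backprop(blocks, x, loss_fn, target):
    inputs, rng_states = [], []
    with torch.no_grad():                  # pass 1: compute loss, cache inputs
        for block in blocks:
            inputs.append(x.clone())       # slowllama writes these to SSD instead
            rng_states.append(torch.random.get_rng_state())
            x = block(x)                   # dropout advances the RNG here
    out = x.detach().requires_grad_(True)
    loss = loss_fn(out, target)
    loss.backward()
    grad = out.grad                        # d(loss)/d(last block's output)
    # pass 2: re-run each block in reverse, backprop within that block only
    for block, inp, rng in zip(reversed(blocks), reversed(inputs),
                               reversed(rng_states)):
        inp = inp.detach().requires_grad_(True)
        torch.random.set_rng_state(rng)    # replay the same dropout masks
        out = block(inp)                   # rebuilds the autograd graph
        out.backward(grad)                 # grads accumulate on (LoRA) params
        grad = inp.grad                    # hand the gradient to the previous block
    return loss.item()
```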
 
 Original llama2 weights are in bfloat16, but the mps backend doesn't support that type natively, so we do computation in float32 instead.
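The fallback can be probed at runtime rather than hardcoded; a small sketch using the `device_supports_dtype` helper from utils.py (shown later in this diff):

```python
import torch
from utils import device_supports_dtype

# use bfloat16 where the backend can represent it, float32 otherwise
# (on an mps device without native bfloat16 this picks float32, as noted above)
compute_dtype = (torch.bfloat16 if device_supports_dtype('mps', torch.bfloat16)
                 else torch.float32)
```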
blackbox_model.py

@@ -6,6 +6,7 @@
 # - simplify init/generation as we only use it for fine-tuning experiments
 # - manual backprop
 # - support for ffn_dim_multiplier which llama2-70b uses
 # - LoRA
 
 import math
 from dataclasses import dataclass
@@ -28,11 +29,14 @@ class ModelArgs:
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
-    dropout: float = 0.0
+    dropout: float = 0.0 # unless we bring back
     ffn_dim_multiplier: Optional[float] = None
     compute_dtype: torch.dtype = torch.float32
     offload_location: str = 'disk' # 'disk' or 'ram'
     rope_theta: float = 10000.0
+    lora_rank: int = 8
+    lora_alpha: int = 64
+    lora_dropout: float = 0.05
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float):
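With the LoRA hyperparameters now living in `ModelArgs`, a run can override them at construction time. A hypothetical usage example (the field values here are illustrative, not defaults from the repo):

```python
# hypothetical: a higher-rank adapter for a larger fine-tuning dataset
args = ModelArgs(lora_rank=16, lora_alpha=32, lora_dropout=0.1)
```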
@@ -221,7 +225,7 @@ class TransformerBlock(nn.Module):
         return out
 
 class LoRA(nn.Module):
-    def __init__(self, original_layer, rank=8, alpha=64, dropout=0.05):
+    def __init__(self, original_layer, rank, alpha, dropout):
         super().__init__()
         n, m = original_layer.weight.shape
         self.A = nn.Linear(m, rank, bias=False)
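The hunk only shows the `A` projection; a sketch of how the full module typically fits together, assuming the standard LoRA formulation (`B` zero-initialized so training starts from the original weights, update scaled by `alpha / rank`) — the repo's exact initialization isn't visible in this diff:

```python
import torch.nn as nn

class LoRASketch(nn.Module):
    def __init__(self, original_layer, rank, alpha, dropout):
        super().__init__()
        n, m = original_layer.weight.shape   # nn.Linear stores (out_features, in_features)
        self.A = nn.Linear(m, rank, bias=False)
        self.B = nn.Linear(rank, n, bias=False)
        nn.init.zeros_(self.B.weight)        # adapter starts as a no-op
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / rank

    def forward(self, x):
        # added to the frozen layer's output: W x + scale * B A dropout(x)
        return self.scale * self.B(self.A(self.dropout(x)))
```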
@@ -250,8 +254,8 @@ class Transformer(nn.Module):
         self.lora_layers = []
         for layer_id in range(params.n_layers):
             block = TransformerBlock(layer_id, params)
-            q_lora = LoRA(block.attention.wq).to(params.compute_dtype)
-            v_lora = LoRA(block.attention.wv).to(params.compute_dtype)
+            q_lora = LoRA(block.attention.wq, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
+            v_lora = LoRA(block.attention.wv, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
             self.lora_layers.append({'q_lora': q_lora, 'v_lora': v_lora})
             self.add_module(f'q_lora_{layer_id}', q_lora)
             self.add_module(f'v_lora_{layer_id}', v_lora)
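Because each adapter is registered via `add_module`, its parameters are visible through the usual `nn.Module` API. A sketch of building an optimizer over LoRA weights only, selecting by the `q_lora_*`/`v_lora_*` names above (the repo's actual selection logic may differ):

```python
import torch

lora_params = [p for name, p in model.named_parameters() if 'lora' in name]
opt = torch.optim.AdamW(lora_params, lr=1e-4)  # frozen base weights get no updates
```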
@@ -341,7 +345,7 @@ class Transformer(nn.Module):
         norm_out = norm_out.detach()
         norm_out.requires_grad = True
 
-        # TODO: micro-optimization: as output is last layer, we can skip loading it second time
+        # TODO: micro-optimization: as output is the last layer, we can skip loading and running it a second time
         logits = self.output(norm_out)
         logits = logits.detach()
         logits.requires_grad = True
finetune.py (16 changed lines)
@@ -15,7 +15,6 @@ seed = 54321
 iters = 1000
 device = 'mps' # mps for macbooks
 seq_len = 128
-dropout = 0.01
 batch_size = 16
 lr = 1e-4
 offload_to = 'disk'
@@ -30,7 +29,7 @@ gen_tokens = 32
 log_lora_grad = False
 log_lora_weight = True
 
-model_path = '/Volumes/LLAMAS//llama-2-70b'
+model_path = '../llama-2-7b'
 finetune_file = './README.md'
 prompt = 'slowllama is a '
 
@@ -41,7 +40,7 @@ with open(finetune_file) as f:
 tokenizer_path = os.path.join(model_path, 'tokenizer.model')
 tokenizer = Tokenizer(tokenizer_path)
 
-def greedy_gen(prompt, max_new_tokens=50):
+def greedy_gen(prompt, iter, max_new_tokens=50):
     tokens = torch.tensor(tokenizer.encode(prompt, True, False)).view(1, -1).to(device)
     model.eval()
     for _ in range(max_new_tokens):
@@ -49,11 +48,12 @@ def greedy_gen(prompt, max_new_tokens=50):
         logits = logits[:, -1, :]
         logits_top, next_tokens = torch.topk(logits, k=25, dim=-1)
         next_token = next_tokens[0, 0].view(1, 1)
-        logging.info(f'next tokens: {logits_top} {next_tokens} {tokenizer.decode(next_tokens.tolist())}')
+        logging.info(f'next token: {next_token}')
+        #logging.info(f'next tokens: {logits_top} {next_tokens} {tokenizer.decode(next_tokens.tolist())}')
         tokens = torch.cat((tokens, next_token), dim=1)
 
-    for i, output in enumerate(tokens):
-        logging.info(f'{i} - {tokenizer.decode(output.tolist())}')
+    for output in tokens:
+        logging.info(f'after {iter} iterations: {tokenizer.decode(output.tolist())}')
 
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='finetune.log')
@@ -66,7 +66,7 @@ if __name__ == '__main__':
 
     logging.info(f'loaded dataset: {len(tokens)} tokens')
 
-    model = load_llama2(model_path, dropout=dropout, compute_dtype=compute_dtype, offload_location=offload_to).to(device).to(compute_dtype)
+    model = load_llama2(model_path, compute_dtype=compute_dtype, offload_location=offload_to).to(device).to(compute_dtype)
 
     def get_batch(batch_size):
         index = torch.randint(len(tokens) - seq_len, (batch_size,))
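The body of `get_batch` is cut off by the hunk; the standard next-token sampler matching the shown first line would look roughly like this (a sketch, not necessarily the repo's exact code):

```python
def get_batch(batch_size):
    # random windows of seq_len tokens; y is x shifted by one for next-token prediction
    index = torch.randint(len(tokens) - seq_len, (batch_size,))
    x = torch.stack([torch.tensor(tokens[i:i + seq_len]) for i in index])
    y = torch.stack([torch.tensor(tokens[i + 1:i + seq_len + 1]) for i in index])
    return x.to(device), y.to(device)
```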
@@ -82,7 +82,7 @@ if __name__ == '__main__':
         X, y = get_batch(batch_size)
         opt.zero_grad()
         if i % eval_period == 0:
-            greedy_gen(prompt, max_new_tokens=gen_tokens)
+            greedy_gen(prompt, i, max_new_tokens=gen_tokens)
         # both forward and backward passes are here.
         # returned loss is a scalar, not a variable
         logits, loss = model.manual_loop(X, y)
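Since `manual_loop` accumulates gradients internally and returns a detached scalar loss, only the optimizer step remains. A plausible continuation of the loop above (hypothetical; the rest of the loop isn't shown in this hunk):

```python
        # gradients were already accumulated inside manual_loop
        opt.step()
        logging.info(f'iteration {i}: loss={loss}')
```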
@@ -7,7 +7,6 @@ import glob
 import logging
 
 from blackbox_model import Transformer, ModelArgs
-from utils import peak_rss_mb
 
 vocab_size = 32000
@@ -1,60 +0,0 @@
-import torch
-import os
-import resource
-
-def device_map(device):
-    if str(device).startswith('mps'):
-        return 'mps'
-    return str(device)
-
-def device_supports_dtype(device, dtype):
-    try:
-        tensor = torch.tensor([1.0, 2.0]).to(device).to(dtype)
-        return True
-    except TypeError as e:
-        return False
-
-global_id_auto = 0
-
-def next_id():
-    global global_id_auto
-    res = torch.tensor(global_id_auto)
-    global_id_auto += 1
-    return res
-
-def intermediate_path(id):
-    if torch.is_tensor(id):
-        id = id.item()
-    folder = f'{os.path.dirname(__file__)}/data'
-    if not os.path.exists(folder):
-        os.makedirs(folder)
-    return f'{folder}/saved_{id}.pt'
-
-def save_rng_state(device='cpu'):
-    if device == 'cpu':
-        import torch
-        return torch.random.get_rng_state()
-    elif device.startswith('cuda'):
-        import torch.cuda
-        return torch.cuda.get_rng_state(device=int(device.split(':')[1]))
-    elif device.startswith('mps'):
-        import torch.mps
-        return torch.mps.get_rng_state()
-    else:
-        raise ValueError(f"Unsupported device: {device}")
-
-def restore_rng_state(rng_state, device='cpu'):
-    if device == 'cpu':
-        import torch
-        torch.random.set_rng_state(rng_state)
-    elif device.startswith('cuda'):
-        import torch.cuda
-        torch.cuda.set_rng_state(rng_state, device=int(device.split(':')[1]))
-    elif device.startswith('mps'):
-        import torch.mps
-        torch.mps.set_rng_state(rng_state)
-    else:
-        raise ValueError(f"Unsupported device: {device}")
-
-def peak_rss_mb():
-    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // (1024 * 1024)
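How the `save_rng_state`/`restore_rng_state` helpers (which remain in utils.py) are meant to be used around an offloaded block: capture the state before the first forward pass, restore it before the replay, so the dropout masks match. A sketch with a hypothetical `block` and input `x`:

```python
state = save_rng_state(device='mps')
out1 = block(x)                        # dropout draws from the generator
restore_rng_state(state, device='mps')
out2 = block(x)                        # same masks, so out1 == out2
```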
utils.py (6 changed lines)
@@ -1,6 +1,5 @@
 import torch
 import os
-import resource
 
 def device_map(device):
     if str(device).startswith('mps'):
@@ -54,7 +53,4 @@ def restore_rng_state(rng_state, device='cpu'):
         import torch.mps
         torch.mps.set_rng_state(rng_state)
     else:
         raise ValueError(f"Unsupported device: {device}")
-
-def peak_rss_mb():
-    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // (1024 * 1024)