381 lines
14 KiB
Python
381 lines
14 KiB
Python
|
# based on model.py from https://github.com/karpathy/llama2.c by Andrej Karpathy, MIT licensed
|
||
|
|
||
|
# modifications by okuvshynov include:
|
||
|
# - no weight tying
|
||
|
# - using blackbox offloadable modules
|
||
|
# - simplify init/generation as we only use it for fine-tuning experiments
|
||
|
# - manual backprop
|
||
|
# - support for ffn_dim_multiplier which llama2-70b uses
|
||
|
# - LoRA
|
||
|
|
||
|
import logging
|
||
|
import math
|
||
|
|
||
|
from typing import Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import nn
|
||
|
|
||
|
from blackbox import BlackboxDisk
|
||
|
from utils import save_rng_state, restore_rng_state, device_map, cleanup_cache
|
||
|
from model_config import ModelArgs
|
||
|
|
||
|
import logging
|
||
|
|
||
|
|
||
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Args:
        dim: size of the last (feature) dimension; length of ``weight``.
        eps: constant added to the mean square before taking the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        # gain starts at 1, so initially the layer only normalizes
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # scale each feature vector by 1 / sqrt(mean(x^2) + eps)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # normalize in float32, cast back to the caller's dtype, then apply the gain
        as_float = x.float()
        normed = self._norm(as_float).type_as(x)
        return normed * self.weight
|
||
|
|
||
|
|
||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the RoPE cosine/sine tables.

    Returns ``(cos, sin)``, each of shape ``(end, dim // 2)``, where entry
    ``[t, k]`` is the cosine / sine of ``t * theta ** (-2k / dim)``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # cos is the real part, sin the imaginary part of exp(i * angle)
    return torch.cos(angles), torch.sin(angles)
|
||
|
|
||
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View a 2-D table so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned view
    keeps those sizes at axis 1 and at the last axis of ``x`` and has size 1 at
    every other axis.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last = rank - 1
    target = [1] * rank
    target[1] = x.shape[1]
    target[last] = x.shape[last]
    return freqs_cis.view(target)
|
||
|
|
||
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings (RoPE) to queries and keys.

    The last dimension of ``xq`` / ``xk`` is treated as interleaved
    (real, imaginary) pairs. Each pair is rotated by the angle whose cosine and
    sine are given in ``freqs_cos`` / ``freqs_sin``. Those tables must have shape
    ``(xq.shape[1], xq.shape[-1] // 2)`` (enforced by ``reshape_for_broadcast``),
    i.e. axis 1 of the inputs is the position axis.

    The rotation is computed in float32 and each result is cast back to the
    dtype of its input.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes as the inputs.
    """
    # reshape xq and xk to (..., last_dim // 2, 2) and split off the pair members
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting; the view built from xq
    # also broadcasts against xk because it has size 1 on every axis other than
    # the position axis and the last axis (xk may have fewer heads than xq)
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # re-interleave the pairs and merge the last two dims back into one.
    # flatten(-2) is identical to flatten(3) for 4-D inputs, but unlike a fixed
    # start dim it also restores the original shape for inputs of other ranks
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)
|
||
|
|
||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for an
    input of shape ``(bs, slen, n_kv_heads, head_dim)``. The result has shape
    ``(bs, slen, n_kv_heads * n_rep, head_dim)``. When ``n_rep == 1`` the input
    tensor itself is returned.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep != 1:
        # add an axis after the heads, broadcast it to n_rep (no copy yet),
        # then fold it into the head axis; the reshape produces the copy
        widened = x.unsqueeze(3).expand(bs, slen, n_kv_heads, n_rep, head_dim)
        return widened.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    return x
|
||
|
|
||
|
class Attention(nn.Module):
    """Pre-norm causal self-attention sub-layer with grouped KV heads and LoRA hooks.

    ``forward`` normalizes its input with its own RMSNorm, projects to queries,
    keys and values, applies RoPE, runs causal attention and returns
    ``input + output``. The residual add happens inside this module. The LoRA
    adapters for the query and value projections are not owned here; they are
    passed into ``forward`` and their outputs are added to ``wq(x)`` / ``wv(x)``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # without an explicit n_kv_heads every query head gets its own K/V head
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads = args.n_heads
        # number of query heads sharing each key/value head
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads

        # LoRA is injected on wq (adapter output added in forward)
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)

        # LoRA is injected on wv (adapter output added in forward)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)

        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # TODO: probably don't need dropout here as we don't plan to do full finetune
        # or maybe we do.
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # logging.warn is a deprecated alias; logging.warning is the supported API
            logging.warning("using slow attention. Flash Attention requires PyTorch >= 2.0")
            # additive causal mask: -inf strictly above the diagonal, 0 elsewhere
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        q_lora: nn.Module,
        v_lora: nn.Module
    ):
        """Return ``x + attention(attention_norm(x))``.

        Args:
            x: input of shape ``(bsz, seqlen, dim)``.
            freqs_cos, freqs_sin: RoPE tables for the first ``seqlen`` positions.
            q_lora, v_lora: adapters whose outputs are added to the ``wq`` and
                ``wv`` projections respectively.
        """
        bsz, seqlen, _ = x.shape

        x_base = x
        x = self.attention_norm(x)

        # QKV; LoRA deltas are added to the query and value projections only
        xq, xk, xv = self.wq(x) + q_lora(x), self.wk(x), self.wv(x) + v_lora(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash implementation
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_heads, seqlen, seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return x_base + output
|
||
|
|
||
|
|
||
|
class FeedForward(nn.Module):
    """Pre-norm SwiGLU feed-forward sub-layer with an internal residual add.

    ``forward`` computes ``x + dropout(w2(silu(w1(n)) * w3(n)))`` where
    ``n = ffn_norm(x)``.

    Args:
        dim: model width (input and output size of the sub-layer).
        hidden_dim: nominal hidden width. It is scaled by 2/3, then optionally
            by ``ffn_dim_multiplier``, then rounded up to a multiple of
            ``multiple_of``.
        multiple_of: granularity the hidden width is rounded up to.
        dropout: dropout probability applied to the sub-layer output.
        ffn_dim_multiplier: optional extra scale for the hidden width.
        args: model config; only ``norm_eps`` is read here.
    """

    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float, ffn_dim_multiplier: Optional[float], args: ModelArgs):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round up to the next multiple of `multiple_of`
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        # the norm must match the width this module actually consumes (`dim`,
        # the input size of w1/w3), not whatever args.dim happens to be
        self.ffn_norm = RMSNorm(dim, eps=args.norm_eps)

    def forward(self, x):
        x_base = x
        x = self.ffn_norm(x)
        return x_base + self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
||
|
|
||
|
class TransformerBlock(nn.Module):
    """One decoder layer: an attention sub-layer followed by a feed-forward sub-layer.

    Both sub-layers are wrapped in ``BlackboxDisk``. The wrapped ``Attention``
    and ``FeedForward`` modules each apply their own RMSNorm and residual add,
    so this block only chains the two calls.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.attention = BlackboxDisk(Attention(args), args)
        ffn = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            args=args,
        )
        self.feed_forward = BlackboxDisk(ffn, args)
        self.layer_id = layer_id

    def forward(self, x, freqs_cos, freqs_sin, lora_q, lora_v):
        # lora_q / lora_v are handed through to the attention sub-layer
        attn_out = self.attention(x, freqs_cos, freqs_sin, lora_q, lora_v)
        return self.feed_forward(attn_out)
|
||
|
|
||
|
class LoRA(nn.Module):
    """Low-rank additive adapter for a linear layer.

    For a wrapped layer whose weight has shape ``(out, in)``, the adapter holds
    bias-free projections ``A: in -> rank`` and ``B: rank -> out``. ``B`` is
    zero-initialized, so the adapter contributes nothing until trained. The
    output is ``dropout(B(A(x))) * alpha / rank``.
    """

    def __init__(self, original_layer, rank, alpha, dropout):
        super().__init__()
        out_features, in_features = original_layer.weight.shape
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.B.weight)
        self.dropout = nn.Dropout(dropout)
        self.scale = alpha / rank

    def expanded(self):
        """Return the dense ``(out, in)`` weight delta ``B @ A * scale``."""
        return torch.mm(self.B.weight, self.A.weight) * self.scale

    def forward(self, x):
        low_rank = self.A(x)
        delta = self.B(low_rank)
        return self.dropout(delta) * self.scale
|
||
|
|
||
|
class Transformer(nn.Module):
    """Llama-style decoder whose base weights live in ``BlackboxDisk`` wrappers.

    Per-layer LoRA adapters are attached to the query and value projections.
    ``forward`` runs the stack and returns the output layer's result for the
    last position only. ``manual_loop`` runs one forward pass plus a hand-written
    backward pass that re-runs each offloaded module. It freezes the base
    module parameters, so gradients are populated on the LoRA adapters, and it
    returns the loss as a Python float.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = BlackboxDisk(nn.Embedding(params.vocab_size, params.dim), params)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()

        # LoRA adapters are created separately from the offloaded blocks, as we
        # don't want to load/save them continuously. Each one is registered as a
        # submodule (q_lora_{i} / v_lora_{i}) and also kept in lora_layers for
        # per-layer lookup.
        self.lora_layers = []
        for layer_id in range(params.n_layers):
            block = TransformerBlock(layer_id, params)

            # TODO: remove this one
            # the adapters are sized from the underlying wq / wv weights
            attn = block.attention.loaded_inner()
            q_lora = LoRA(attn.wq, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
            v_lora = LoRA(attn.wv, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype)
            self.lora_layers.append({'q_lora': q_lora, 'v_lora': v_lora})
            self.add_module(f'q_lora_{layer_id}', q_lora)
            self.add_module(f'v_lora_{layer_id}', v_lora)
            self.layers.append(block)
            logging.debug(f'created transformer block {layer_id}')

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        # Freeze the final norm. Assigning `module.requires_grad = False` only
        # sets a plain attribute and leaves the weight trainable;
        # requires_grad_ updates the parameters themselves.
        self.norm.requires_grad_(False)
        self.output = BlackboxDisk(nn.Linear(params.dim, params.vocab_size, bias=False), params)

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len, theta=params.rope_theta)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Run the model on ``tokens`` (``(bsz, seqlen)``) and return the output
        layer's result for the last position only (``h[:, [-1], :]``)."""
        _bsz, seqlen = tokens.shape

        # dummy input to force gradient propagation to blackbox modules
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer, lora in zip(self.layers, self.lora_layers):
            h = layer(h, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
        h = self.norm(h)

        return self.output(h[:, [-1], :])

    def backprop_w_lora(self, blackbox_module, output_grad, *args):
        """Re-run ``blackbox_module`` on its stored input and backpropagate ``output_grad``.

        The wrapped module and its input are loaded onto ``output_grad.device``
        via ``load`` / ``load_input`` (presumably the input recorded during the
        forward pass; that logic lives in BlackboxDisk). The module's own
        parameters are frozen, so ``backward`` only produces gradients for the
        input and for any tensors reachable through ``args`` that require grad
        (e.g. the LoRA adapters passed to attention).

        Returns:
            The gradient w.r.t. the module input.
        """
        device = output_grad.device
        module = blackbox_module.load(device)

        # we use LoRA and only update the attached low-rank modules;
        # no part of the original model gets any updates, so no need for gradient
        for param in module.parameters():
            param.requires_grad = False

        module_input = blackbox_module.load_input(device)
        module_input.requires_grad = True

        output = module(module_input, *args)
        output.backward(output_grad)

        return module_input.grad if module_input.requires_grad else None

    # this is a manual implementation of the forward/backward passes
    def manual_loop(self, tokens, targets):
        """Run one forward pass and a manual backward pass; return the loss as a float.

        Activations between offloaded modules are detached. In the backward
        pass each module is re-run via ``backprop_w_lora`` to get input
        gradients. The RNG state is saved before every attention and
        feed-forward call and restored before the matching re-run, so dropout
        masks in the re-run match the forward pass. Targets equal to -1 are
        ignored by the loss.
        """
        logging.debug('starting manual loop')
        device = device_map(tokens.device)

        embd_out = self.tok_embeddings(tokens)
        embd_out = embd_out.detach()
        embd_out.requires_grad = True
        logging.debug('done embedding')

        _, seqlen = tokens.shape

        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        current = self.dropout(embd_out)
        del embd_out

        # one (attention_rng, feed_forward_rng) pair per layer
        rng_before = []

        for i, (layer, lora) in enumerate(zip(self.layers, self.lora_layers)):
            # Same computation as TransformerBlock.forward, split so the RNG
            # state can be captured before each sub-layer. The backward pass
            # re-runs feed-forward before attention, so a single per-layer
            # state would give both re-runs different dropout masks than the
            # forward pass used.
            attn_rng = save_rng_state(device)
            current = layer.attention(current, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
            ffn_rng = save_rng_state(device)
            current = layer.feed_forward(current)
            rng_before.append((attn_rng, ffn_rng))
            logging.debug(f'forward: transformer block {i} done')

        current = current.detach()
        current.requires_grad = True

        norm_out = self.norm(current)
        norm_out = norm_out.detach()
        norm_out.requires_grad = True

        # TODO: micro-optimization: as output is last layer, we can skip loading and running it second time
        logging.debug('output layer')
        logits = self.output(norm_out)
        del norm_out

        logging.debug('output layer done')

        # compute the loss in float32 even when the model runs in lower precision
        if self.params.compute_dtype != torch.float32:
            logits = logits.to(torch.float32)

        logits = logits.detach()
        logits.requires_grad = True

        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        logging.debug('forward: computed loss')

        loss.backward()

        # gradient w.r.t. the output layer's input (the normalized hidden state)
        norm_out_grad = self.backprop_w_lora(self.output, logits.grad.to(self.params.compute_dtype))
        del logits
        logging.debug('combined: output layer done')

        # re-run the final norm with grad enabled to carry the gradient to `current`
        norm_out2 = self.norm(current)
        norm_out2.backward(norm_out_grad)
        del norm_out_grad
        del norm_out2

        last_grad = current.grad
        del current

        for i, (layer, (attn_rng, ffn_rng), lora) in enumerate(zip(reversed(self.layers), reversed(rng_before), reversed(self.lora_layers))):
            # first, do feed_forward, with the RNG state it saw in the forward pass
            cleanup_cache(device)
            restore_rng_state(ffn_rng, device=device)
            last_grad = self.backprop_w_lora(layer.feed_forward, last_grad)

            # now, do attention, again restoring its own forward-pass RNG state
            cleanup_cache(device)
            restore_rng_state(attn_rng, device=device)
            last_grad = self.backprop_w_lora(layer.attention, last_grad, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
            logging.debug(f'combined: transformer block {i} done')

        # no need to backpropagate through embeddings: there are no LoRA layers there.
        return loss.item()
|