# based on model.py from https://github.com/karpathy/llama2.c by Andrej Karpathy, MIT licenced # modifications by okuvshynov include: # - no weight tying # - using blackbox offloadable modules # - simplify init/generation as we only use it for fine-tuning experiments # - manual backprop # - support for ffn_dim_multiplier which llama2-70b uses # - LoRA import logging import math from typing import Optional, Tuple import torch import torch.nn.functional as F from torch import nn from blackbox import BlackboxDisk from utils import save_rng_state, restore_rng_state, device_map, cleanup_cache from model_config import ModelArgs import logging class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore freqs_cos = torch.cos(freqs) # real part freqs_sin = torch.sin(freqs) # imaginary part return freqs_cos, freqs_sin def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(shape) def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # reshape xq and xk to match the complex representation xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1) xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1) # reshape freqs_cos and freqs_sin for broadcasting freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) # apply rotation using real numbers xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos # flatten last two dimensions xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x return ( x[:, :, :, None, :] .expand(bs, slen, n_kv_heads, n_rep, head_dim) .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_heads = args.n_heads self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = args.dim // args.n_heads # here's where we inject LoRA self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) # here's where we inject LoRA self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) # TODO: probably don't need dropout here as we don't plan to do full finetune # or maybe we do. self.attn_dropout = nn.Dropout(args.dropout) self.resid_dropout = nn.Dropout(args.dropout) self.dropout = args.dropout # use flash attention or a manual implementation? self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') if not self.flash: logging.warn("using slow attention. Flash Attention requires PyTorch >= 2.0") mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) mask = torch.triu(mask, diagonal=1) self.register_buffer("mask", mask) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, q_lora: nn.Module, v_lora: nn.Module ): bsz, seqlen, _ = x.shape x_base = x x = self.attention_norm(x) # QKV xq, xk, xv = self.wq(x) + q_lora(x), self.wk(x), self.wv(x) + v_lora(x) xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) # RoPE relative positional embeddings xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) # grouped multiquery attention: expand out keys and values xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_heads, head_dim) xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_heads, head_dim) # make heads into a batch dimension xq = xq.transpose(1, 2) # (bs, n_heads, seqlen, head_dim) xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) # flash implementation if self.flash: output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) assert hasattr(self, 'mask') scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = self.attn_dropout(scores) output = torch.matmul(scores, xv) # (bs, n_heads, seqlen, head_dim) # restore time as batch dimension and concat heads output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) # final projection into the residual stream output = self.wo(output) output = self.resid_dropout(output) return x_base + output class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float, ffn_dim_multiplier: Optional[float], args: ModelArgs): super().__init__() hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) self.dropout = nn.Dropout(dropout) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward(self, x): x_base = x x = self.ffn_norm(x) return x_base + self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads self.attention = BlackboxDisk(Attention(args), args) self.feed_forward = BlackboxDisk(FeedForward( dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, dropout=args.dropout, ffn_dim_multiplier=args.ffn_dim_multiplier, args=args ), args) self.layer_id = layer_id def forward(self, x, freqs_cos, freqs_sin, lora_q, lora_v): h = self.attention(x, freqs_cos, freqs_sin, lora_q, lora_v) out = self.feed_forward(h) return out class LoRA(nn.Module): def __init__(self, original_layer, rank, alpha, dropout): super().__init__() n, m = original_layer.weight.shape self.A = nn.Linear(m, rank, bias=False) self.B = nn.Linear(rank, n, bias=False) nn.init.zeros_(self.B.weight) self.dropout = nn.Dropout(dropout) self.scale = alpha / rank # return matrix to add to original weight def expanded(self): res = self.B.weight.mm(self.A.weight) * self.scale return res def forward(self, x): return self.dropout(self.B(self.A(x))) * self.scale class Transformer(nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params self.vocab_size = params.vocab_size self.n_layers = params.n_layers self.tok_embeddings = BlackboxDisk(nn.Embedding(params.vocab_size, params.dim), params) self.dropout = nn.Dropout(params.dropout) self.layers = torch.nn.ModuleList() # we create LoRA adapters separately. As we don't want to load/save them continously self.lora_layers = [] for layer_id in range(params.n_layers): block = TransformerBlock(layer_id, params) # TODO: remove this one attn = block.attention.loaded_inner() q_lora = LoRA(attn.wq, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype) v_lora = LoRA(attn.wv, rank=params.lora_rank, alpha=params.lora_alpha, dropout=params.lora_dropout).to(params.compute_dtype) self.lora_layers.append({ 'q_lora': q_lora, 'v_lora': v_lora}) self.add_module(f'q_lora_{layer_id}', q_lora) self.add_module(f'v_lora_{layer_id}', v_lora) self.layers.append(block) logging.debug(f'created transformer block {layer_id}') self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.norm.requires_grad = False self.output = BlackboxDisk(nn.Linear(params.dim, params.vocab_size, bias=False), params) # some useful precompute for the RoPE relative positional embeddings freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len, theta=params.rope_theta) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) def forward(self, tokens: torch.Tensor) -> torch.Tensor: _bsz, seqlen = tokens.shape # dummy input to force gradient propagation to blackbox modules h = self.tok_embeddings(tokens) h = self.dropout(h) freqs_cos = self.freqs_cos[:seqlen] freqs_sin = self.freqs_sin[:seqlen] for layer, lora in zip(self.layers, self.lora_layers): h = layer(h, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora']) h = self.norm(h) return self.output(h[:, [-1], :]) def backprop_w_lora(self, blackbox_module, output_grad, *args): device = output_grad.device module = blackbox_module.load(device) # we use LoRA and only updated attached low-rank modules # no part of original model is getting any updates, so no need for gradient for param in module.parameters(): param.requires_grad = False input = blackbox_module.load_input(device) input.requires_grad = True output = module(input, *args) output.backward(output_grad) return input.grad if input.requires_grad else None # this is a manual implementation on forward/backward passes def manual_loop(self, tokens, targets): logging.log(level=logging.DEBUG, msg=f'starting manual loop') device = device_map(tokens.device) embd_out = self.tok_embeddings(tokens) embd_out = embd_out.detach() embd_out.requires_grad = True logging.log(level=logging.DEBUG, msg=f'done embedding') _, seqlen = tokens.shape freqs_cos = self.freqs_cos[:seqlen] freqs_sin = self.freqs_sin[:seqlen] current = self.dropout(embd_out) del embd_out rng_before = [] for i, (layer, lora) in enumerate(zip(self.layers, self.lora_layers)): rng_before.append(save_rng_state(device)) current = layer(current, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora']) logging.log(level=logging.DEBUG, msg=f'forward: transformer block {i} done') current = current.detach() current.requires_grad = True norm_out = self.norm(current) norm_out = norm_out.detach() norm_out.requires_grad = True # TODO: micro-optimization: as output is last layer, we can skip loading and running it second time logging.log(level=logging.DEBUG, msg=f'output layer') logits = self.output(norm_out) del norm_out logging.log(level=logging.DEBUG, msg=f'output layer done') if (self.params.compute_dtype != torch.float32): logits = logits.to(torch.float32) logits = logits.detach() logits.requires_grad = True loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) logging.log(level=logging.DEBUG, msg=f'forward: computed loss') loss.backward() norm_out_grad = self.backprop_w_lora(self.output, logits.grad.to(self.params.compute_dtype)) del logits logging.log(level=logging.DEBUG, msg=f'combined: output layer done') norm_out2 = self.norm(current) norm_out2.backward(norm_out_grad) del norm_out_grad del norm_out2 last_grad = current.grad del current for i, (layer, rng_state, lora) in enumerate(zip(reversed(self.layers), reversed(rng_before), reversed(self.lora_layers))): cleanup_cache(device) restore_rng_state(rng_state, device=device) # first, do feed_forward last_grad = self.backprop_w_lora(layer.feed_forward, last_grad) # now, do attention cleanup_cache(device) last_grad = self.backprop_w_lora(layer.attention, last_grad, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora']) logging.log(level=logging.DEBUG, msg=f'combined: transformer block {i} done') # no need to backpropagate through embeddings no LoRA layers there. return loss.item()