rename model-specific files to llama2_

I'll try to add mistral/dbrx support, and they might need different
logic for loading/eval/backprop.
Oleksandr Kuvshynov 2024-03-28 10:14:10 -04:00
parent f055a88bdd
commit 0d8a4cf5dc
9 changed files with 11 additions and 11 deletions

@@ -96,11 +96,11 @@ python test_gen.py ./out/state_dict_19.pth
### How does it work?
For all versions the process is roughly the same.
-First, we need to be able to load a model which requires more RAM than we have and save it back in a sequential format. We create a model instance with all large modules' weights offloaded to SSD - all of the transformer blocks, token embeddings and the output linear layer. After that we [load model shards one by one](https://github.com/okuvshynov/slowllama/blob/main/loader.py#L69), for each shard iterate over all modules, update the corresponding subset of their weights and save them back.
+First, we need to be able to load a model which requires more RAM than we have and save it back in a sequential format. We create a model instance with all large modules' weights offloaded to SSD - all of the transformer blocks, token embeddings and the output linear layer. After that we [load model shards one by one](https://github.com/okuvshynov/slowllama/blob/main/llama2_loader.py#L69), for each shard iterate over all modules, update the corresponding subset of their weights and save them back.
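To make that merge loop concrete, here is a minimal self-contained sketch. The file layout, helper names, and the contiguous-slice sharding along one dimension are assumptions for illustration, not the actual slowllama code (llama2 actually shards some tensors by rows and some by columns):

```python
import torch

# Illustrative sketch of the shard-merging loop: module_files maps a weight
# name to the .pt file where that module's full tensor lives on SSD; each
# shard is assumed to own a contiguous 1/n slice of every tensor along
# shard_dim. Not the actual slowllama implementation.
def merge_shards(module_files, shard_paths, shard_dim=1):
    for i, shard_path in enumerate(shard_paths):
        shard = torch.load(shard_path, map_location='cpu')  # one shard in RAM at a time
        for name, module_file in module_files.items():
            full = torch.load(module_file)                  # bring module weights back from SSD
            width = full.shape[shard_dim] // len(shard_paths)
            # copy this shard's slice into the right region of the full tensor
            full.narrow(shard_dim, i * width, width).copy_(shard[name])
            torch.save(full, module_file)                   # write merged weights back to SSD
        del shard                                           # free RAM before the next shard
```

Peak RAM here is one shard plus one module's tensor, rather than the whole checkpoint at once.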
Doing a forward pass is easy - we just load modules when we need them and pass the output forward.
-The backward pass is a little more tricky; in effect, we have to run the forward pass twice. The way it's [currently implemented](https://github.com/okuvshynov/slowllama/blob/main/blackbox_model.py#L351) is:
+The backward pass is a little more tricky; in effect, we have to run the forward pass twice. The way it's [currently implemented](https://github.com/okuvshynov/slowllama/blob/main/llama2.py#L307) is:
1. Do a forward pass while saving the inputs to each offloaded block to SSD. The goal of this first pass is to compute the final loss and cache the inputs to each offloaded block.
2. Then, do manual backward gradient propagation. Starting from the last block, we re-run each block once more (forward, to build the autograd graph) with the same input we cached in step (1). We then run a backward pass within that block only and pass the gradient with respect to its input on to the previous block (see the sketch after this list). As we use LoRA, only LoRA gradients are saved; LoRA weights are not offloaded to disk and always stay in RAM/GPU memory. Important: we also need to save and restore the random number generator state before evaluating each offloaded module. Training uses dropout, and the randomly switched-off neurons must be the same on both forward passes.
3. After that we run an optimizer step on the LoRA weights and save them separately if needed.
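As a concrete illustration of steps (1)-(3), here is a toy, self-contained version of the two-pass scheme. Offloading is replaced by plain in-RAM tensors, all parameters get gradients rather than LoRA-only ones, and only the CPU RNG state is handled (utils.py also deals with per-device state). Illustrative, not the actual implementation:

```python
import torch
import torch.nn as nn

# Toy version of the two-pass backward scheme. Real slowllama offloads blocks
# and cached inputs to SSD and trains only LoRA weights; here everything
# stays in RAM and every parameter receives a gradient.
def train_step(blocks, x, target, loss_fn=nn.functional.mse_loss):
    # Pass 1: forward only. Cache each block's input and the RNG state so
    # dropout masks can be replayed exactly on the second pass.
    inputs, rng_states = [], []
    with torch.no_grad():
        for block in blocks:
            inputs.append(x)
            rng_states.append(torch.random.get_rng_state())
            x = block(x)

    # Pass 2: from the last block to the first, re-run one block with
    # autograd enabled, backprop through it, and hand the gradient of its
    # input over to the previous block.
    grad_out = None
    for block, inp, rng in zip(reversed(blocks), reversed(inputs), reversed(rng_states)):
        inp = inp.detach().requires_grad_(True)
        torch.random.set_rng_state(rng)      # replay the same dropout masks
        out = block(inp)                     # autograd graph for this block only
        if grad_out is None:
            loss_fn(out, target).backward()  # last block: recompute the loss
        else:
            out.backward(grad_out)           # inner block: use downstream gradient
        grad_out = inp.grad
```

The gradients match a normal end-to-end backward pass, but the autograd graph for only one block exists at any moment, which is what bounds peak memory.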
@@ -254,9 +254,9 @@ ANE (Apple's Neural Engine) power consumption.....
Just a few files with no dependencies other than torch, numpy and sentencepiece for the tokenizer.
-1. [blackbox_model.py](blackbox_model.py) -- model definition and manual backprop implementation. It's based on model.py from [llama2.c](https://github.com/karpathy/llama2.c), also MIT licensed.
+1. [llama2.py](llama2.py) -- model definition and manual backprop implementation. It's based on model.py from [llama2.c](https://github.com/karpathy/llama2.c), also MIT licensed.
2. [finetune.py](finetune.py) - script which does the training
-3. [loader.py](loader.py) - manual loading/saving of large llama2 models
+3. [llama2_loader.py](llama2_loader.py) - manual loading/saving of large llama2 models
4. [utils.py](utils.py) - small utility functions, including saving/loading random generator state for different devices.
5. [test_gen.py](test_gen.py) - greedily complete the prompt. Takes base weights + trained LoRA weights as input. Useful for sanity checks.
6. [blackbox.py](blackbox.py) - module wrapper which offloads the wrapped module to disk or main memory (see the sketch below).
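The wrapper in (6) boils down to a module that parks its weights on disk and only materializes them around calls. A rough sketch of the idea, with hypothetical names (illustrative, not the actual blackbox.py):

```python
import torch
import torch.nn as nn

# Rough sketch of a disk-offloading module wrapper in the spirit of
# blackbox.py; class and method names are made up for illustration.
class DiskOffloadedModule(nn.Module):
    def __init__(self, module: nn.Module, path: str):
        super().__init__()
        self.path = path
        torch.save(module, path)       # park the wrapped module on SSD

    def loaded_inner(self) -> nn.Module:
        return torch.load(self.path)   # materialize weights in RAM on demand

    def save(self, module: nn.Module):
        torch.save(module, self.path)  # write (possibly updated) weights back

    def forward(self, x):
        module = self.loaded_inner()
        out = module(x)
        del module                     # release weights from RAM right after use
        return out
```

Every forward() call pays one SSD read, which is exactly the speed/memory trade-off slowllama is named after.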

@@ -3,7 +3,7 @@ import torch
import sys
import os
-from loader import load_frozen
+from llama2_loader import load_frozen
from utils import Tokenizer, greedy_gen2
from conf_fp16 import *

@@ -2,7 +2,7 @@ import os
import torch
import logging
-from loader import load_frozen
+from llama2_loader import load_frozen
from plot_lora import log_lora
from utils import Tokenizer, greedy_gen

@@ -3,7 +3,7 @@ import sys
import torch
import logging
-from loader import load_frozen
+from llama2_loader import load_frozen
from plot_lora import log_lora
from datasets import load_dataset
from utils import Tokenizer, greedy_gen

@@ -7,7 +7,7 @@ import logging
import shutil
from model_config import ModelArgs
-from blackbox_model import Transformer
+from llama2 import Transformer
from utils import device_supports_dtype
# how are weights sharded in llama2 - by rows or columns

@@ -3,7 +3,7 @@ import os
import sys
import shutil
-from loader import add_lora
+from llama2_loader import add_lora
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO, filename='logs/merge_lora.log')

@@ -3,7 +3,7 @@
import torch
import logging
-from loader import prepare_model
+from llama2_loader import prepare_model
from conf_fp16 import *
logging.basicConfig(format='%(asctime)s %(message)s',

@@ -3,7 +3,7 @@ import torch
import sys
import os
-from loader import load_frozen
+from llama2_loader import load_frozen
from utils import Tokenizer, greedy_gen
from conf_fp16 import *