mirror of https://github.com/vllm-project/vllm
Add LoRA support for Mixtral (#2831)
* add mixtral lora support
* formatting
* fix incorrectly ported logic
* polish tests
* minor fixes and refactoring
* minor fixes
* formatting
* rename and remove redundant logic
* refactoring
* refactoring
* minor fix
* minor refactoring
* fix code smell
This commit is contained in:
parent 317b29de0f
commit 2a543d6efe
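For context, a minimal sketch of how the feature added by this commit is exercised, loosely based on the new Mixtral LoRA test included in the diff below. The adapter path and the prompt string are placeholders; the engine arguments mirror the test (enable_lora, max_loras, tensor parallelism) and may need tuning for other setups.

import vllm
from vllm.lora.request import LoRARequest

# Build an engine with LoRA support enabled; the test below serves
# Mixtral-8x7B with tensor_parallel_size=4 and worker_use_ray=True.
llm = vllm.LLM("mistralai/Mixtral-8x7B-Instruct-v0.1",
               enable_lora=True,
               max_loras=4,
               tensor_parallel_size=4)

# Attach an adapter per generate() call: LoRARequest takes a name,
# an integer id, and the local directory holding the adapter weights.
outputs = llm.generate(
    ["<placeholder prompt>"],
    vllm.SamplingParams(temperature=0, max_tokens=256),
    lora_request=LoRARequest("example-adapter", 1, "/path/to/adapter"))
print(outputs[0].outputs[0].text)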
@@ -121,6 +121,11 @@ def sql_lora_files():
    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
@@ -11,25 +11,35 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
from vllm.lora.models import (LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear

EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]


def test_from_lora_tensors(sql_lora_files):
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
    lora_model = LoRAModel.from_lora_tensors(1,
                                             8,
                                             16,
                                             tensors,
                                             "cuda",
                                             embeddings=new_embeddings)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        8,
        16,
        tensors,
        "cuda",
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
@@ -90,14 +100,11 @@ def create_packed_lora(

def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    manager = LoRAModelManager(model,
                               1,
                               1,
                               1,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=8,
                                          max_loras=8),
                               lora_target_modules=["dense1", "layer1.dense2"])
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
@@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model):

def test_lora_model_manager(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
@@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model):

def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
@@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["dense1", "dense2", "lm_head"])
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))

    assert all(x is None for x in manager.lora_index_to_id)
@@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                       sql_lora_files):
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
@@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
@@ -428,6 +434,13 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,

def test_packed_loras(dist_init, dummy_model_gate_up):
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
@@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):

    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["gate_up_proj"])
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
@@ -0,0 +1,53 @@
import pytest
import torch

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"


def do_sample(llm, lora_path: str, lora_id: int):
    prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   tensor_parallel_size=tp_size,
                   worker_use_ray=True)

    expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
    ]

    assert do_sample(llm, mixtral_lora_files,
                     lora_id=1) == expected_lora_output
    assert do_sample(llm, mixtral_lora_files,
                     lora_id=2) == expected_lora_output
@@ -4,8 +4,7 @@ import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
                    Union)
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)

import safetensors.torch
import torch
@@ -20,36 +19,6 @@ from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule

logger = logging.getLogger(__name__)

# TODO: The mappings below should be moved to individual model classes.

PACKED_MODULES_CFG = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "gate_up_proj": [
        "gate_proj",
        "up_proj",
    ],
}

TARGET_MODULES_QKV = [
    "qkv_proj",
    "o_proj",
    "gate_up_proj",
    "down_proj",
    "embed_tokens",
    "lm_head",
]

EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]

_GLOBAL_LORA_ID = 0
@@ -169,6 +138,8 @@ class LoRAModel:
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and not in_wsl()
@@ -179,11 +150,11 @@ class LoRAModel:
            lora_embeddings_tensor = None
            if embeddings:
                embeddings_module = next(
                    (k for k in EMBEDDING_MODULES if k in module_name),
                    (k for k in embedding_modules if k in module_name),
                    None)
                if embeddings_module:
                    lora_embeddings_tensor = embeddings[
                        EMBEDDING_MODULES[embeddings_module]].to(
                        embedding_modules[embeddings_module]].to(
                            device=device, dtype=dtype)
                    if pin_memory:
                        lora_embeddings_tensor = (
@@ -201,7 +172,7 @@ class LoRAModel:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                if any(name in module_name
                       for name in EMBEDDING_PADDING_MODULES
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
@@ -218,12 +189,15 @@ class LoRAModel:

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None) -> "LoRAModel":
        cls,
        lora_dir: str,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
@@ -260,6 +234,8 @@ class LoRAModel:
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
@@ -273,8 +249,6 @@ class LoRAModelManager:
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        """Create a LoRAModelManager and adapter for a given model.
@@ -286,13 +260,6 @@ class LoRAModelManager:
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            lora_target_modules: the target modules patterns to be adapted.
                Support both single module name and a list of module names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., qkv_proj
                is packed of q_proj, k_proj, and v_proj. These modules
                have a single layer in the original model, but they are split
                into multiple layers in the adapted model.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
@@ -320,11 +287,11 @@ class LoRAModelManager:
        self.indices_len = [None] * 4

        self.model: nn.Module = model
        self.lora_target_modules: List[str] = ([
            lora_target_modules
        ] if isinstance(lora_target_modules, str) else lora_target_modules)
        self.lora_target_modules = copy.deepcopy(lora_target_modules)
        self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
@@ -468,7 +435,11 @@ class LoRAModelManager:
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
@@ -477,7 +448,7 @@ class LoRAModelManager:
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                if parts[-1] in EMBEDDING_MODULES:
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
@@ -531,7 +502,7 @@ class LoRAModelManager:
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.lora_target_modules)
            for target_module in self.supported_lora_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
@@ -586,12 +557,9 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, lora_target_modules,
                         packed_modules_mapping)
                         vocab_size, lora_config)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
@@ -637,11 +605,10 @@ def create_lora_manager(
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not getattr(model, "supports_lora", False):
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
@@ -649,6 +616,5 @@ def create_lora_manager(
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        lora_target_modules=target_modules,
        **kwargs)
    return lora_manager
@@ -1,10 +1,10 @@
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Union
from typing import Any, Dict, List, Optional, Set, Type

import torch

from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
from vllm.lora.models import (LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
@@ -13,7 +13,7 @@ from vllm.config import LoRAConfig
logger = logging.getLogger(__name__)


class WorkerLoRAManager(ABC):
class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
@@ -33,7 +33,6 @@ class WorkerLoRAManager(ABC):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        ...
@@ -63,7 +62,7 @@ class WorkerLoRAManager(ABC):
        ...


class WorkerLoRAManager(WorkerLoRAManager):
class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
@@ -78,10 +77,14 @@ class WorkerLoRAManager(WorkerLoRAManager):
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_manager: Optional[LoRAModelManager] = None
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)
@@ -92,13 +95,11 @@ class WorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            target_modules=target_modules,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
@@ -142,6 +143,8 @@ class WorkerLoRAManager(WorkerLoRAManager):
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
@@ -162,7 +165,7 @@ class WorkerLoRAManager(WorkerLoRAManager):
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank))
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
@@ -195,11 +198,9 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            target_modules=target_modules,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
@@ -66,7 +66,7 @@ def get_model(model_config: ModelConfig,
    # Create a model instance.
    # The weights will be initialized as empty tensors.
    with torch.device(device_config.device):
        if getattr(model_class, "supports_lora", False):
        if hasattr(model_class, "supported_lora_modules"):
            model = model_class(model_config.hf_config, linear_method,
                                lora_config)
        elif lora_config:
@@ -269,7 +269,32 @@ class LlamaModel(nn.Module):


class LlamaForCausalLM(nn.Module):
    supports_lora = True
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
@@ -281,11 +306,11 @@ class LlamaForCausalLM(nn.Module):
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            unpadded_vocab_size,
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
@@ -293,7 +318,7 @@ class LlamaForCausalLM(nn.Module):
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
@@ -265,7 +265,32 @@ class MistralModel(nn.Module):


class MistralForCausalLM(nn.Module):
    supports_lora = True
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
@@ -27,6 +27,7 @@ import torch
from torch import nn
from transformers import MixtralConfig

from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
@@ -38,7 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
@@ -292,6 +293,7 @@ class MixtralModel(nn.Module):
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.org_vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
@@ -318,18 +320,50 @@ class MixtralModel(nn.Module):


class MixtralForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
@@ -86,11 +86,20 @@ class ModelRunner:
        vocab_size = self.model.config.vocab_size

        if self.lora_config:
            assert hasattr(
                self.model, "supported_lora_modules"
            ) and self.model.supported_lora_modules, "Model does not support LoRA"
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens +
                self.scheduler_config.max_paddings, vocab_size,
                self.lora_config, self.device)
                self.lora_config, self.device, self.model.embedding_modules,
                self.model.embedding_padding_modules)
            self.model = self.lora_manager.create_lora_manager(self.model)

    def set_block_size(self, block_size: int) -> None: