mirror of https://github.com/vllm-project/vllm
Implement lazy model loader (#2044)
This commit is contained in:
parent
30bad5c492
commit
518369d78c
|
@ -7,54 +7,9 @@ import torch.nn as nn
|
|||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import *
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
from vllm.utils import is_hip
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"AquilaModel": AquilaForCausalLM,
|
||||
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||
"BloomForCausalLM": BloomForCausalLM,
|
||||
"ChatGLMModel": ChatGLMForCausalLM,
|
||||
"ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
|
||||
"FalconForCausalLM": FalconForCausalLM,
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||
"GPTJForCausalLM": GPTJForCausalLM,
|
||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||
"InternLMForCausalLM": InternLMForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MistralForCausalLM": MistralForCausalLM,
|
||||
"MixtralForCausalLM": MixtralForCausalLM,
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": MPTForCausalLM,
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"PhiForCausalLM": PhiForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
"YiForCausalLM": YiForCausalLM,
|
||||
}
|
||||
|
||||
# Models to be disabled in ROCm
|
||||
_ROCM_UNSUPPORTED_MODELS = []
|
||||
if is_hip():
|
||||
for rocm_model in _ROCM_UNSUPPORTED_MODELS:
|
||||
del _MODEL_REGISTRY[rocm_model]
|
||||
|
||||
# Models partially supported in ROCm
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
|
||||
"MistralForCausalLM":
|
||||
"Sliding window attention is not supported in ROCm's flash attention",
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
|
|||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
logger.warning(
|
||||
f"{arch} is not fully supported in ROCm. Reason: "
|
||||
f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
|
||||
return _MODEL_REGISTRY[arch]
|
||||
elif arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model architecture {arch} is not supported by ROCm for now. \n"
|
||||
f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
|
|
|
@ -1,41 +1,80 @@
|
|||
from vllm.model_executor.models.aquila import AquilaForCausalLM
|
||||
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
||||
BaichuanForCausalLM)
|
||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||
from vllm.model_executor.models.falcon import FalconForCausalLM
|
||||
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||
from vllm.model_executor.models.mixtral import MixtralForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
|
||||
from vllm.model_executor.models.yi import YiForCausalLM
|
||||
import importlib
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Architecture -> (module, class).
|
||||
_MODELS = {
|
||||
"AquilaModel": ("aquila", "AquilaForCausalLM"),
|
||||
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"YiForCausalLM": ("yi", "YiForCausalLM"),
|
||||
}
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
|
||||
"MistralForCausalLM":
|
||||
"Sliding window attention is not yet supported in ROCm's flash attention",
|
||||
}
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch not in _MODELS:
|
||||
return None
|
||||
if is_hip():
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model architecture {model_arch} is not supported by "
|
||||
"ROCm for now.")
|
||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
logger.warning(
|
||||
f"Model architecture {model_arch} is partially supported "
|
||||
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
||||
|
||||
module_name, model_cls_name = _MODELS[model_arch]
|
||||
module = importlib.import_module(
|
||||
f"vllm.model_executor.models.{module_name}")
|
||||
return getattr(module, model_cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AquilaForCausalLM",
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"ChatGLMForCausalLM",
|
||||
"FalconForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
"GPTJForCausalLM",
|
||||
"GPTNeoXForCausalLM",
|
||||
"InternLMForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"PhiForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
"MistralForCausalLM",
|
||||
"MixtralForCausalLM",
|
||||
"YiForCausalLM",
|
||||
"ModelRegistry",
|
||||
]
|
||||
|
|
|
@ -33,14 +33,15 @@ from transformers import MixtralConfig
|
|||
|
||||
try:
|
||||
import megablocks.ops as ops
|
||||
except ImportError:
|
||||
print(
|
||||
"MegaBlocks not found. Please install it by `pip install megablocks`.")
|
||||
except ImportError as e:
|
||||
raise ImportError("MegaBlocks not found. "
|
||||
"Please install it by `pip install megablocks`.") from e
|
||||
try:
|
||||
import stk
|
||||
except ImportError:
|
||||
print(
|
||||
"STK not found: please see https://github.com/stanford-futuredata/stk")
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"STK not found. "
|
||||
"Please install it by `pip install stanford-stk`.") from e
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
|
|
Loading…
Reference in New Issue