mirror of https://github.com/vllm-project/vllm
[mypy] Enable following imports for some directories (#6681)
parent c32ab8be1a
commit da1f7cc12a
@@ -32,22 +32,17 @@ jobs:
 pip install types-setuptools
 - name: Mypy
 run: |
-mypy tests --config-file pyproject.toml
-mypy vllm/*.py --config-file pyproject.toml
-mypy vllm/attention --config-file pyproject.toml
-mypy vllm/core --config-file pyproject.toml
-mypy vllm/distributed --config-file pyproject.toml
-mypy vllm/engine --config-file pyproject.toml
-mypy vllm/entrypoints --config-file pyproject.toml
-mypy vllm/executor --config-file pyproject.toml
-mypy vllm/inputs --config-file pyproject.toml
-mypy vllm/logging --config-file pyproject.toml
-mypy vllm/lora --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
-mypy vllm/multimodal --config-file pyproject.toml
-mypy vllm/platforms --config-file pyproject.toml
-mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/transformers_utils --config-file pyproject.toml
-mypy vllm/usage --config-file pyproject.toml
-mypy vllm/worker --config-file pyproject.toml
+mypy tests --follow-imports skip
+mypy vllm/attention --follow-imports skip
+mypy vllm/core --follow-imports skip
+mypy vllm/distributed --follow-imports skip
+mypy vllm/engine --follow-imports skip
+mypy vllm/entrypoints --follow-imports skip
+mypy vllm/executor --follow-imports skip
+mypy vllm/lora --follow-imports skip
+mypy vllm/model_executor --follow-imports skip
+mypy vllm/prompt_adapter --follow-imports skip
+mypy vllm/spec_decode --follow-imports skip
+mypy vllm/worker --follow-imports skip
+mypy
30 format.sh
@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
 # Run mypy
 echo 'vLLM mypy:'
-mypy tests --config-file pyproject.toml
-mypy vllm/*.py --config-file pyproject.toml
-mypy vllm/attention --config-file pyproject.toml
-mypy vllm/core --config-file pyproject.toml
-mypy vllm/distributed --config-file pyproject.toml
-mypy vllm/engine --config-file pyproject.toml
-mypy vllm/entrypoints --config-file pyproject.toml
-mypy vllm/executor --config-file pyproject.toml
-mypy vllm/logging --config-file pyproject.toml
-mypy vllm/lora --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
-mypy vllm/multimodal --config-file pyproject.toml
-mypy vllm/prompt_adapter --config-file pyproject.toml
-mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/transformers_utils --config-file pyproject.toml
-mypy vllm/usage --config-file pyproject.toml
-mypy vllm/worker --config-file pyproject.toml
+mypy tests --follow-imports skip
+mypy vllm/attention --follow-imports skip
+mypy vllm/core --follow-imports skip
+mypy vllm/distributed --follow-imports skip
+mypy vllm/engine --follow-imports skip
+mypy vllm/entrypoints --follow-imports skip
+mypy vllm/executor --follow-imports skip
+mypy vllm/lora --follow-imports skip
+mypy vllm/model_executor --follow-imports skip
+mypy vllm/prompt_adapter --follow-imports skip
+mypy vllm/spec_decode --follow-imports skip
+mypy vllm/worker --follow-imports skip
+mypy

 # If git diff returns a file that is in the skip list, the file may be checked anyway:
@@ -48,9 +48,23 @@ python_version = "3.8"
 ignore_missing_imports = true
 check_untyped_defs = true
-follow_imports = "skip"
+follow_imports = "silent"

-files = "vllm"
+# After fixing type errors resulting from follow_imports: "skip" -> "silent",
+# move the directory here and remove it from format.sh and mypy.yaml
+files = [
+"vllm/*.py",
+"vllm/adapter_commons",
+"vllm/assets",
+"vllm/inputs",
+"vllm/logging",
+"vllm/multimodal",
+"vllm/platforms",
+"vllm/server",
+"vllm/transformers_utils",
+"vllm/triton_utils",
+"vllm/usage",
+]
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
 "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
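For context on the `skip` -> `silent` switch above, a minimal sketch of how the two `follow_imports` modes differ (the two modules and the `scale` function are invented for illustration, not part of vllm):

# helper.py -- hypothetical module NOT listed in the mypy `files` section
def scale(x: int) -> int:
    return x * 2

# checked.py -- hypothetical module that IS covered by the mypy run
from helper import scale

# follow_imports = "skip":   helper.py is not analyzed, `scale` is typed as Any,
#                            so the bad argument below goes unreported.
# follow_imports = "silent": helper.py is analyzed (errors inside it are
#                            suppressed), its signature is known, and mypy
#                            flags the str argument here.
scale("3")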
@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
 b: torch.Tensor,
 scale_a: torch.Tensor,
 scale_b: torch.Tensor,
-out_dtype: Type[torch.dtype],
+out_dtype: torch.dtype,
 bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
 assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
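The annotation fix above reflects that values such as torch.float16 are instances of torch.dtype rather than subclasses of it, so Type[torch.dtype] makes mypy reject them. A minimal sketch of the corrected pattern (the function name is illustrative, not from vllm):

from typing import Tuple

import torch

def make_output(shape: Tuple[int, int], out_dtype: torch.dtype) -> torch.Tensor:
    # torch.float16 is an *instance* of torch.dtype, so this annotation
    # type-checks; with Type[torch.dtype] mypy would reject the call below.
    return torch.empty(shape, dtype=out_dtype)

out = make_output((4, 8), torch.float16)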
@@ -25,27 +25,33 @@ class ipex_ops:
 x2 = x2.reshape(num, d)
 return x1, x2

+@staticmethod
 def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
 x1, x2 = ipex_ops._reshape_activation_tensor(x)
 ipex.llm.functional.silu_mul(x1, x2, out)

+@staticmethod
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
 x1, x2 = ipex_ops._reshape_activation_tensor(x)
 ipex.llm.functional.gelu_mul(x1, x2, out, "none")

+@staticmethod
 def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
 x1, x2 = ipex_ops._reshape_activation_tensor(x)
 ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

+@staticmethod
 def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
 out.copy_(torch.nn.functional.gelu(x))

+@staticmethod
 def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
 out.copy_(torch.nn.functional.gelu(x))

 # TODO add implementation of gelu_quick here
 # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

+@staticmethod
 def paged_attention_v1(
 out: torch.Tensor,
 query: torch.Tensor,
@@ -78,12 +84,21 @@ class ipex_ops:
 ).view(num_kv_heads,
 1).repeat_interleave(num_queries_per_tokens).flatten()
 # todo: ipex will refactor namespace
-torch.xpu.paged_attention_v1(out, query.contiguous(),
-key_cache.view_as(value_cache),
-value_cache, head_mapping, scale,
-block_tables, context_lens, block_size,
-max_context_len, alibi_slopes)
+torch.xpu.paged_attention_v1( # type: ignore
+out,
+query.contiguous(),
+key_cache.view_as(value_cache),
+value_cache,
+head_mapping,
+scale,
+block_tables,
+context_lens,
+block_size,
+max_context_len,
+alibi_slopes,
+)

+@staticmethod
 def paged_attention_v2(
 out: torch.Tensor,
 exp_sum: torch.Tensor,
@@ -119,13 +134,24 @@ class ipex_ops:
 ).view(num_kv_heads,
 1).repeat_interleave(num_queries_per_tokens).flatten()
 # todo: ipex will refactor namespace
-torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
-query.contiguous(),
-key_cache.view_as(value_cache),
-value_cache, head_mapping, block_tables,
-context_lens, scale, block_size,
-max_context_len, alibi_slopes)
+torch.xpu.paged_attention_v2( # type: ignore
+out,
+exp_sum,
+max_logits,
+tmp_out,
+query.contiguous(),
+key_cache.view_as(value_cache),
+value_cache,
+head_mapping,
+block_tables,
+context_lens,
+scale,
+block_size,
+max_context_len,
+alibi_slopes,
+)

+@staticmethod
 def rotary_embedding(
 positions: torch.Tensor,  # [batch_size, seq_len]
 query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
@@ -158,6 +184,7 @@ class ipex_ops:
 ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
 rotary_dim, is_neox, positions)

+@staticmethod
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
 key: torch.Tensor, head_size: int,
 cos_sin_cache: torch.Tensor, is_neox: bool,
@@ -189,17 +216,20 @@ class ipex_ops:
 ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
 rotary_dim, is_neox, positions)

+@staticmethod
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
 epsilon: float) -> None:
 tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
 out.copy_(tmp)

+@staticmethod
 def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
 weight: torch.Tensor, epsilon: float) -> None:
 tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
 epsilon, True)
 input.copy_(tmp)

+@staticmethod
 def varlen_attention(
 query: torch.Tensor,
 key: torch.Tensor,
@@ -222,6 +252,7 @@ class ipex_ops:
 softmax_scale, zero_tensors,
 is_causal, return_softmax, gen_)

+@staticmethod
 def reshape_and_cache(
 key: torch.Tensor,
 value: torch.Tensor,
@@ -240,8 +271,13 @@ class ipex_ops:
 def copy_blocks(key_caches: List[torch.Tensor],
 value_caches: List[torch.Tensor],
 block_mapping: torch.Tensor) -> None:
-torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
+torch.xpu.copy_blocks( # type: ignore
+key_caches,
+value_caches,
+block_mapping,
+)

+@staticmethod
 def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
 block_mapping: torch.Tensor) -> None:
-torch.xpu.swap_blocks(src, dst, block_mapping)
+torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
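Two mypy-driven patterns recur in the hunks above: free functions grouped in a namespace class are marked @staticmethod so the checker does not treat the first Tensor parameter as an invalid self, and calls into backend attributes mypy cannot see (the torch.xpu.* entry points) are silenced with a targeted # type: ignore. A minimal sketch of the first pattern, with a hypothetical class name rather than vllm's ipex_ops:

import torch

class fake_ops:
    # Without @staticmethod, mypy treats the first parameter of a method as
    # `self` and flags the Tensor annotation as an invalid self type.

    @staticmethod
    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

x = torch.randn(4)
out = torch.empty_like(x)
fake_ops.gelu_fast(out, x)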
@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
 super().__init__(capacity)
 self.deactivate_fn = deactivate_fn

-def _on_remove(self, key: Hashable, value: T):
+def _on_remove(self, key: Hashable, value: Optional[T]):
 logger.debug("Removing adapter int id: %d", key)
 self.deactivate_fn(key)
 return super()._on_remove(key, value)
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):

 @property
 @abstractmethod
-def adapter_slots(self):
-...
+def adapter_slots(self) -> int:
+raise NotImplementedError

 @property
 @abstractmethod
-def capacity(self):
-...
+def capacity(self) -> int:
+raise NotImplementedError

 @abstractmethod
 def activate_adapter(self, adapter_id: int) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
 def deactivate_adapter(self, adapter_id: int) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
 def add_adapter(self, adapter: Any) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
 def set_adapter_mapping(self, mapping: Any) -> None:
-...
+raise NotImplementedError

 @abstractmethod
 def remove_adapter(self, adapter_id: int) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
-def remove_all_adapters(self):
-...
+def remove_all_adapters(self) -> None:
+raise NotImplementedError

 @abstractmethod
 def get_adapter(self, adapter_id: int) -> Optional[Any]:
-...
+raise NotImplementedError

 @abstractmethod
 def list_adapters(self) -> Dict[int, Any]:
-...
+raise NotImplementedError

 @abstractmethod
 def pin_adapter(self, adapter_id: int) -> bool:
-...
+raise NotImplementedError
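A small illustration of the pattern applied throughout this class, with hypothetical names rather than vllm's: the return annotations give mypy concrete signatures to check concrete subclasses against, and raising NotImplementedError keeps the declared return type satisfied in the abstract base.

from abc import ABC, abstractmethod

class BaseManager(ABC):

    @property
    @abstractmethod
    def adapter_slots(self) -> int:
        raise NotImplementedError

class EightSlotManager(BaseManager):

    @property
    def adapter_slots(self) -> int:
        return 8

print(EightSlotManager().adapter_slots)  # 8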
@@ -1,19 +1,19 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from dataclasses import dataclass


 @dataclass
-class AdapterRequest:
+class AdapterRequest(ABC):
 """
 Base class for adapter requests.
 """

 @property
 @abstractmethod
-def adapter_id(self):
-...
+def adapter_id(self) -> int:
+raise NotImplementedError

-def __post_init__(self):
+def __post_init__(self) -> None:
 if self.adapter_id < 1:
 raise ValueError(f"id must be > 0, got {self.adapter_id}")
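A hedged sketch of how a concrete request satisfies this base. The subclass below is invented for illustration (vllm's real request classes live elsewhere), and the base is restated so the sketch is self-contained: the dataclass subclass supplies the abstract adapter_id property, and the inherited __post_init__ still validates it.

from abc import ABC, abstractmethod
from dataclasses import dataclass

# Stand-in for the AdapterRequest base shown above.
@dataclass
class AdapterRequestBase(ABC):

    @property
    @abstractmethod
    def adapter_id(self) -> int:
        raise NotImplementedError

    def __post_init__(self) -> None:
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

@dataclass
class DummyAdapterRequest(AdapterRequestBase):
    dummy_id: int

    @property
    def adapter_id(self) -> int:
        return self.dummy_id

DummyAdapterRequest(dummy_id=7)       # passes validation
try:
    DummyAdapterRequest(dummy_id=0)   # inherited __post_init__ rejects ids < 1
except ValueError:
    pass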
@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
 @property
 @abstractmethod
 def is_enabled(self) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
 def set_active_adapters(self, requests: Set[Any],
 mapping: Optional[Any]) -> None:
-...
+raise NotImplementedError

 @abstractmethod
 def add_adapter(self, adapter_request: Any) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
 def remove_adapter(self, adapter_id: int) -> bool:
-...
+raise NotImplementedError

 @abstractmethod
-def remove_all_adapters(self):
-...
+def remove_all_adapters(self) -> None:
+raise NotImplementedError

 @abstractmethod
 def list_adapters(self) -> Set[int]:
-...
+raise NotImplementedError
@@ -724,7 +724,7 @@ class ParallelConfig:
 backend)

 self._verify_args()
-self.rank = 0
+self.rank: int = 0

 @property
 def use_ray(self) -> bool:
@@ -850,6 +850,7 @@ class SchedulerConfig:


 class DeviceConfig:
+device: Optional[torch.device]

 def __init__(self, device: str = "auto") -> None:
 if device == "auto":
@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union

-from transformers import PreTrainedTokenizer
-
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
 EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
 init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
+from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
+BaseTokenizerGroup,
 get_tokenizer_group)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 usage_message)
@@ -477,13 +476,12 @@ class LLMEngine:
 return self.tokenizer

 def get_tokenizer(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

-def get_tokenizer_for_seq(self,
-sequence: Sequence) -> "PreTrainedTokenizer":
+def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
 return self.get_tokenizer_group().get_lora_tokenizer(
 sequence.lora_request)
@@ -5,7 +5,6 @@ from http import HTTPStatus
 from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

 from pydantic import Field
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
@@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer_group import AnyTokenizer

 logger = init_logger(__name__)

@@ -49,8 +49,6 @@ class LoRAModulePath:
 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
 EmbeddingRequest, TokenizeRequest]

-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-

 class TextTokensPrompt(TypedDict):
 prompt: str
@@ -4,9 +4,10 @@ import asyncio
 import os
 import signal
 import sys
-from typing import Optional
+from typing import List, Optional

 from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam

 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:

 def chat(system_prompt: Optional[str], model_name: str,
 client: OpenAI) -> None:
-conversation = []
+conversation: List[ChatCompletionMessageParam] = []
 if system_prompt is not None:
 conversation.append({"role": "system", "content": system_prompt})

 print("Please enter a message for the chat model:")
 while True:
 input_message = input("> ")
-message = {"role": "user", "content": input_message}
-conversation.append(message)
+conversation.append({"role": "user", "content": input_message})

 chat_completion = client.chat.completions.create(model=model_name,
 messages=conversation)
@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
 response_message = chat_completion.choices[0].message
 output = response_message.content

-conversation.append(response_message)
+conversation.append(response_message) # type: ignore
 print(output)
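The List[ChatCompletionMessageParam] annotation above is what lets mypy check the literal dicts appended to the conversation. A short sketch of the same pattern in isolation; the messages are placeholders:

from typing import List

from openai.types.chat import ChatCompletionMessageParam

conversation: List[ChatCompletionMessageParam] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# A typo such as {"rol": "user", ...} or an unsupported role is now caught by
# mypy instead of surfacing as a failed request at runtime.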
@@ -37,6 +37,8 @@ class Detokenizer:
 The prompt logprobs with the decoded tokens.
 """
+prms = seq_group.sampling_params
+assert prms is not None

 # We can pick any sequence for the prompt.
 seq = next(iter(seq_group.seqs_dict.values()))
 # Only prompt, without the generated token.
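The added assert is a standard way to narrow an Optional for mypy: sampling_params is declared Optional, and the assertion tells the checker it cannot be None past that point. A self-contained sketch of the same idea, with illustrative names:

from typing import List, Optional

def first_token(prompt_token_ids: Optional[List[int]]) -> int:
    # Without the assert, mypy reports that Optional[List[int]] is not indexable.
    assert prompt_token_ids is not None
    return prompt_token_ids[0]

print(first_token([101, 7592, 102]))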
@@ -2,10 +2,9 @@ from typing import Optional, Type

 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-BaseTokenizerGroup)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-TokenizerGroup)
+
+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .tokenizer_group import TokenizerGroup

 if ray:
 from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
 return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)


-__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
+__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest

+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+

 class BaseTokenizerGroup(ABC):
 """A group of tokenizers that can be used for LoRA adapters."""
@@ -47,17 +49,17 @@ class BaseTokenizerGroup(ABC):

 @abstractmethod
 def get_lora_tokenizer(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 """Get a tokenizer for a LoRA request."""
 pass

 @abstractmethod
 async def get_lora_tokenizer_async(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 """Get a tokenizer for a LoRA request."""
 pass
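Hoisting the AnyTokenizer alias here lets every call site annotate against a single name instead of repeating the Union. An illustrative use; the helper below is not part of vllm:

from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def token_count(tokenizer: AnyTokenizer, text: str) -> int:
    # Both the slow and the fast tokenizer expose encode(), so the body
    # type-checks against either member of the Union.
    return len(tokenizer.encode(text))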
@@ -6,18 +6,16 @@ try:
 from ray.exceptions import ActorDiedError
 except ImportError:
 # For older versions of Ray
-from ray.exceptions import RayActorError as ActorDiedError
+from ray.exceptions import RayActorError as ActorDiedError # type: ignore
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-from transformers import PreTrainedTokenizer

 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-BaseTokenizerGroup)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-TokenizerGroup)
+
+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .tokenizer_group import TokenizerGroup

 logger = init_logger(__name__)

@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
 **self._tokenizer_config, )

 self._ray_tokenizer_group_cls = ray.remote(
-self._worker_cls).options(**ray_actor_options)
+self._worker_cls).options(**ray_actor_options) # type: ignore
 self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
 self._idle_actors: Optional[asyncio.Queue] = None

@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
 return len(self.tokenizer_actors)

 def ping(self):
-return ray.get(
-[actor.ping.remote() for actor in self.tokenizer_actors])
+return ray.get([
+actor.ping.remote() # type: ignore
+for actor in self.tokenizer_actors
+])

 def _ensure_queue_initialized(self):
 if self._idle_actors is None:
@@ -208,15 +208,15 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
 return self._local_tokenizer_group.get_max_input_len(lora_request)

 def get_lora_tokenizer(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 return self._local_tokenizer_group.get_lora_tokenizer(lora_request)

 async def get_lora_tokenizer_async(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 return await self._local_tokenizer_group.get_lora_tokenizer_async(
 lora_request)
@@ -1,16 +1,14 @@
 from typing import List, Optional

-from transformers import PreTrainedTokenizer
-
 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
 get_lora_tokenizer_async,
 get_tokenizer)
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-BaseTokenizerGroup)
 from vllm.utils import LRUCache

+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup

 class TokenizerGroup(BaseTokenizerGroup):
 """A group of tokenizers that can be used for LoRA adapters."""
@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
 self.enable_lora = enable_lora
 self.max_input_length = max_input_length
 self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
-capacity=max_num_seqs) if enable_lora else None
+self.lora_tokenizers = LRUCache[AnyTokenizer](
+capacity=max_num_seqs if enable_lora else 0)

 @classmethod
 def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
 return self.max_input_length

 def _raise_if_input_too_long(self,
-encoded_tokens: List[str],
+encoded_tokens: List[int],
 lora_request: Optional[LoRARequest] = None):
 input_length = len(encoded_tokens)
 if lora_request:
@@ -72,9 +70,9 @@ class TokenizerGroup(BaseTokenizerGroup):
 return ret

 def get_lora_tokenizer(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 if not lora_request or not self.enable_lora:
 return self.tokenizer
 if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
 self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
 return tokenizer
 else:
-return self.lora_tokenizers.get(lora_request.lora_int_id)
+return self.lora_tokenizers[lora_request.lora_int_id]

 async def get_lora_tokenizer_async(
-self,
-lora_request: Optional[LoRARequest] = None
-) -> "PreTrainedTokenizer":
+self,
+lora_request: Optional[LoRARequest] = None,
+) -> AnyTokenizer:
 if not lora_request or not self.enable_lora:
 return self.tokenizer
 if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
 self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
 return tokenizer
 else:
-return self.lora_tokenizers.get(lora_request.lora_int_id)
+return self.lora_tokenizers[lora_request.lora_int_id]
@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
 def __len__(self) -> int:
 return len(self.cache)

-def __getitem__(self, key: Hashable) -> Optional[T]:
-return self.get(key)
+def __getitem__(self, key: Hashable) -> T:
+value = self.cache[key] # Raise KeyError if not exists
+self.cache.move_to_end(key)
+return value

 def __setitem__(self, key: Hashable, value: T) -> None:
 self.put(key, value)
@@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
 def get(self,
 key: Hashable,
 default_value: Optional[T] = None) -> Optional[T]:
+value: Optional[T]
 if key in self.cache:
-value: Optional[T] = self.cache[key]
+value = self.cache[key]
 self.cache.move_to_end(key)
 else:
 value = default_value
@@ -590,8 +593,8 @@ class CudaMemoryProfiler:
 torch.cuda.reset_peak_memory_stats(self.device)
 mem = torch.cuda.max_memory_allocated(self.device)
 elif is_xpu():
-torch.xpu.reset_peak_memory_stats(self.device)
-mem = torch.xpu.max_memory_allocated(self.device)
+torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
+mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
 return mem

 def __enter__(self):
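Taken together with the TokenizerGroup change above (cache[...] instead of cache.get(...)), the new __getitem__ makes indexing strict while get() keeps its Optional contract. A small usage sketch, assuming the vllm.utils.LRUCache shown in this diff:

from vllm.utils import LRUCache

cache: LRUCache[str] = LRUCache(capacity=2)
cache.put(1, "a")

assert cache[1] == "a"        # indexing returns T and refreshes recency
assert cache.get(2) is None   # .get() still returns Optional[T] with a default
try:
    cache[2]                  # ...but indexing a missing key now raises
except KeyError:
    pass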