[mypy] Enable following imports for some directories (#6681)

Author: Cyrus Leung, 2024-07-31 10:38:03 +08:00 (committed by GitHub)
parent c32ab8be1a
commit da1f7cc12a
18 changed files with 185 additions and 143 deletions


@@ -32,22 +32,17 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/inputs --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy


@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy
# If git diff returns a file that is in the skip list, the file may be checked anyway:


@@ -48,9 +48,23 @@ python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "skip"
follow_imports = "silent"
files = "vllm"
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from format.sh and mypy.yaml
files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
"vllm/platforms",
"vllm/server",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",


@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
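The annotation fix above reflects that `torch.float16` and `torch.bfloat16` are instances of `torch.dtype`, not subclasses of it, so `Type[torch.dtype]` described a value that callers never actually pass. A minimal sketch (function names are hypothetical) of how mypy reads the two annotations:

from typing import Type

import torch


def takes_dtype(out_dtype: torch.dtype) -> None:
    # torch.float16 is an *instance* of torch.dtype, so this signature matches
    # real call sites such as cutlass_scaled_mm above.
    assert out_dtype in (torch.float16, torch.bfloat16)


def takes_dtype_class(out_dtype: Type[torch.dtype]) -> None:
    # Type[torch.dtype] means "the torch.dtype class itself (or a subclass)".
    pass


takes_dtype(torch.float16)        # OK under mypy
takes_dtype_class(torch.dtype)    # what the old annotation actually asked for
# takes_dtype_class(torch.float16)  # mypy: incompatible type "dtype";
#                                   #       expected "Type[dtype]"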


@@ -25,27 +25,33 @@ class ipex_ops:
x2 = x2.reshape(num, d)
return x1, x2
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out)
@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none")
@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
@staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
@staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@staticmethod
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
@@ -78,12 +84,21 @@ class ipex_ops:
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v1( # type: ignore
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
@@ -119,13 +134,24 @@ class ipex_ops:
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, block_tables,
context_lens, scale, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v2( # type: ignore
out,
exp_sum,
max_logits,
tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
@@ -158,6 +184,7 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
@@ -189,17 +216,20 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp)
@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True)
input.copy_(tmp)
@staticmethod
def varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale, zero_tensors,
is_causal, return_softmax, gen_)
@staticmethod
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
@@ -240,8 +271,13 @@ class ipex_ops:
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping)
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
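The `# type: ignore` comments above are needed because the `torch.xpu.*` kernels are attached to the module at runtime (via the IPEX extension), so the stubs mypy analyzes have no matching attributes. A self-contained sketch of the same mechanism, using a deliberately nonexistent attribute so it stays runnable anywhere:

import math


def tau_or_fallback() -> float:
    # mypy reports: Module has no attribute "tau_approx"  [attr-defined]
    # The trailing comment suppresses only that error, mirroring how the
    # dynamically registered torch.xpu ops are handled above.
    try:
        return math.tau_approx  # type: ignore[attr-defined]
    except AttributeError:
        return math.tau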


@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T):
def _on_remove(self, key: Hashable, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@property
@abstractmethod
def adapter_slots(self):
...
def adapter_slots(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def capacity(self):
...
def capacity(self) -> int:
raise NotImplementedError
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
raise NotImplementedError
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
raise NotImplementedError
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
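Besides swapping `...` for `raise NotImplementedError`, the abstract members above now declare concrete return types, so call sites see `int`/`bool`/`Dict[int, Any]` instead of `Any`, and a forgotten override fails loudly at runtime. A minimal, self-contained sketch of the same pattern with a hypothetical interface and implementation:

from abc import ABC, abstractmethod
from typing import Dict, Optional


class AdapterStore(ABC):                      # hypothetical, not a vLLM class
    """Abstract members declare exact return types and raise if not overridden."""

    @property
    @abstractmethod
    def capacity(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[str]:
        raise NotImplementedError


class InMemoryAdapterStore(AdapterStore):

    def __init__(self) -> None:
        self._adapters: Dict[int, str] = {1: "demo"}

    @property
    def capacity(self) -> int:
        return len(self._adapters)

    def get_adapter(self, adapter_id: int) -> Optional[str]:
        return self._adapters.get(adapter_id)


store = InMemoryAdapterStore()
slots: int = store.capacity                  # precise type, no cast needed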


@@ -1,19 +1,19 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
class AdapterRequest(ABC):
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self):
...
def adapter_id(self) -> int:
raise NotImplementedError
def __post_init__(self):
def __post_init__(self) -> None:
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
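Deriving `AdapterRequest` from `ABC` stops the bare base dataclass from being instantiated, while the dataclass-generated `__init__` of each concrete subclass still runs the inherited `__post_init__` validation. A self-contained sketch with a hypothetical concrete request:

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest(ABC):
    """Base class for adapter requests (mirrors the diff above)."""

    @property
    @abstractmethod
    def adapter_id(self) -> int:
        raise NotImplementedError

    def __post_init__(self) -> None:
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")


@dataclass
class DemoAdapterRequest(AdapterRequest):    # hypothetical concrete request
    demo_int_id: int

    @property
    def adapter_id(self) -> int:
        return self.demo_int_id


DemoAdapterRequest(demo_int_id=3)    # ok; inherited __post_init__ validates it
# DemoAdapterRequest(demo_int_id=0)  # raises ValueError
# AdapterRequest()                   # TypeError: can't instantiate abstract class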


@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@property
@abstractmethod
def is_enabled(self) -> bool:
...
raise NotImplementedError
@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> Set[int]:
...
raise NotImplementedError


@@ -724,7 +724,7 @@ class ParallelConfig:
backend)
self._verify_args()
self.rank = 0
self.rank: int = 0
@property
def use_ray(self) -> bool:
@@ -850,6 +850,7 @@ class SchedulerConfig:
class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "auto") -> None:
if device == "auto":
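Both changes in this file follow the same idea: declaring an attribute's type explicitly (`self.rank: int = 0`, or `device: Optional[torch.device]` at class level) tells mypy what the attribute holds even when different branches or later methods assign it. A small illustrative sketch (not vLLM's actual class):

from typing import Optional


class DeviceConfigSketch:
    # Declared up front so every assignment branch is checked against it.
    device: Optional[str]

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            self.device = None          # resolved lazily elsewhere
        else:
            self.device = device
        self.rank: int = 0              # explicit, so reassignments must be int

    def resolve(self) -> str:
        return self.device or "cpu"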


@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
@@ -477,13 +476,12 @@ class LLMEngine:
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)


@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
@@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
logger = init_logger(__name__)
@@ -49,8 +49,6 @@ class LoRAModulePath:
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str


@@ -4,9 +4,10 @@ import asyncio
import os
import signal
import sys
from typing import Optional
from typing import List, Optional
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None:
conversation = []
conversation: List[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:")
while True:
input_message = input("> ")
message = {"role": "user", "content": input_message}
conversation.append(message)
conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create(model=model_name,
messages=conversation)
@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message = chat_completion.choices[0].message
output = response_message.content
conversation.append(response_message)
conversation.append(response_message) # type: ignore
print(output)
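`ChatCompletionMessageParam` is a union of TypedDicts from the `openai` package, so the `role`/`content` dict literals above type-check directly; the `# type: ignore` is needed because the assistant reply is a `ChatCompletionMessage` model rather than one of those TypedDicts. One hedged alternative to the ignore (not what the commit does) is to append a param-shaped dict built from the reply text:

from typing import List

from openai.types.chat import ChatCompletionMessageParam

# Assumes the `openai` package is installed; `reply_text` stands in for
# chat_completion.choices[0].message.content in the snippet above.
reply_text = "Hello! How can I help?"

conversation: List[ChatCompletionMessageParam] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]
conversation.append({"role": "assistant", "content": reply_text})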


@@ -37,6 +37,8 @@ class Detokenizer:
The prompt logprobs with the decoded tokens.
"""
prms = seq_group.sampling_params
assert prms is not None
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
# Only prompt, without the generated token.
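`seq_group.sampling_params` is an Optional, and the added assertion is what lets mypy narrow it before `prms` is used; without it, attribute access on `prms` fails strict Optional checking. A minimal sketch of the same narrowing pattern with illustrative names:

from typing import Optional


class ParamsSketch:                       # illustrative stand-in
    logprobs: Optional[int] = 1


def prompt_logprobs(params: Optional[ParamsSketch]) -> Optional[int]:
    # Without the assert, mypy reports:
    #   Item "None" of "Optional[ParamsSketch]" has no attribute "logprobs"
    assert params is not None             # narrows Optional[ParamsSketch] to ParamsSketch
    return params.logprobs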


@@ -2,10 +2,9 @@ from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]


@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
@@ -47,17 +49,17 @@ class BaseTokenizerGroup(ABC):
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
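The `AnyTokenizer` alias exists because `PreTrainedTokenizerFast` is not a subclass of `PreTrainedTokenizer` (both derive from `PreTrainedTokenizerBase`), so code that can hand back either kind needs the explicit union rather than the old `"PreTrainedTokenizer"` annotation. A short sketch of the alias in use (the checkpoint name is only an example):

from typing import Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def load_tokenizer(name: str) -> AnyTokenizer:
    # AutoTokenizer may return either a "slow" PreTrainedTokenizer or a
    # PreTrainedTokenizerFast; the fast class is not a subclass of the slow
    # one, so the union is the honest return type.
    return AutoTokenizer.from_pretrained(name)


if __name__ == "__main__":
    tok = load_tokenizer("gpt2")          # downloads the tokenizer files
    print(type(tok).__name__)             # e.g. GPT2TokenizerFast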


@@ -6,18 +6,16 @@ try:
from ray.exceptions import ActorDiedError
except ImportError:
# For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
logger = init_logger(__name__)
@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
**self._tokenizer_config, )
self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options)
self._worker_cls).options(**ray_actor_options) # type: ignore
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
self._idle_actors: Optional[asyncio.Queue] = None
@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return len(self.tokenizer_actors)
def ping(self):
return ray.get(
[actor.ping.remote() for actor in self.tokenizer_actors])
return ray.get([
actor.ping.remote() # type: ignore
for actor in self.tokenizer_actors
])
def _ensure_queue_initialized(self):
if self._idle_actors is None:
@@ -208,15 +208,15 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)


@@ -1,16 +1,14 @@
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.utils import LRUCache
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
self.lora_tokenizers = LRUCache[AnyTokenizer](
capacity=max_num_seqs if enable_lora else 0)
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[str],
encoded_tokens: List[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
@@ -72,9 +70,9 @@ class TokenizerGroup(BaseTokenizerGroup):
return ret
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]


@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __getitem__(self, key: Hashable) -> T:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value)
@@ -109,8 +111,9 @@
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
if key in self.cache:
value: Optional[T] = self.cache[key]
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
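This is the contract the `TokenizerGroup` change above relies on: `cache[key]` now returns `T` and raises `KeyError` on a miss, while `cache.get(key)` keeps returning `Optional[T]`, so call sites that have just checked membership can index and avoid a spurious Optional. A self-contained sketch of the same distinction:

from collections import OrderedDict
from typing import Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class TinyLRUCache(Generic[T]):           # illustrative stand-in for vllm's LRUCache
    def __init__(self) -> None:
        self.cache: "OrderedDict[Hashable, T]" = OrderedDict()

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)

    def __getitem__(self, key: Hashable) -> T:
        value = self.cache[key]           # raises KeyError if the key is absent
        self.cache.move_to_end(key)
        return value

    def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
        if key not in self.cache:
            return default
        return self[key]


cache: TinyLRUCache[str] = TinyLRUCache()
cache["a"] = "tokenizer-a"
known: str = cache["a"]                   # typed as str, no None to unwrap
maybe: Optional[str] = cache.get("b")     # typed as Optional[str], may be None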
@@ -590,8 +593,8 @@ class CudaMemoryProfiler:
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():
torch.xpu.reset_peak_memory_stats(self.device)
mem = torch.xpu.max_memory_allocated(self.device)
torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem
def __enter__(self):