diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index 1e31ff7373..cb125a7bfe 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -89,3 +89,6 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
+
+    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_tokenizer() is not None
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 65ab0c0634..5629d1a863 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -5,6 +5,8 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 
+from transformers import PreTrainedTokenizer
+
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8484014c9a..5b46d9db56 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -7,6 +7,8 @@ import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional,
                     Tuple, Union)
 
+from transformers import PreTrainedTokenizer
+
 import vllm
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ class LLMEngine:
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index ba352f18f6..7d5603c85e 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a8244fd150..c673b2582c 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 6edc225cdf..2600ea2642 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -120,7 +120,8 @@ class TokenizerGroup:
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ class TokenizerGroup:
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
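Usage sketch (not part of the diff): after this change, AsyncLLMEngine.get_tokenizer() is a coroutine, so callers must await it. When engine_use_ray is set it resolves through the Ray actor via self.engine.get_tokenizer.remote(); otherwise it delegates to LLMEngine.get_tokenizer(), which returns the base tokenizer from TokenizerGroup.get_lora_tokenizer(), now callable with no arguments because of the new lora_request=None default. The model name below is a placeholder chosen for illustration, not something the diff prescribes:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    async def main() -> None:
        # Placeholder model; any model vLLM can load works here.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))

        # New API shape: await the tokenizer instead of calling the old
        # synchronous get_tokenizer(). The same call works with or
        # without engine_use_ray.
        tokenizer = await engine.get_tokenizer()
        print(tokenizer.encode("Hello, world!"))


    asyncio.run(main())

This is also why the two OpenAI serving entrypoints above change from `self.engine.get_tokenizer()` to `await self.engine.get_tokenizer()`: both already run inside coroutines, so awaiting the new method is a one-line migration.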