mirror of https://github.com/vllm-project/vllm
[BugFix] Fix get tokenizer when using ray (#3301)
commit 9e8744a545 (parent e4a28e5316)
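Summary of the change, as reflected in the hunks below: with engine_use_ray=True, AsyncLLMEngine.engine is a Ray actor handle, so the old synchronous get_tokenizer(), which reached directly into self.engine.tokenizer.tokenizer, could not work. get_tokenizer() becomes a coroutine that dispatches through .remote() on the Ray path, LLMEngine gains a matching get_tokenizer() for the actor to serve, the OpenAI serving layers await the new coroutine, and TokenizerGroup.get_lora_tokenizer() gets a default argument so it can be called without a LoRA request.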
@@ -89,3 +89,6 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
+
+    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_tokenizer() is not None
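Note (not part of the diff): MockAsyncLLMEngine is the test stand-in for AsyncLLMEngine used in this file; constructing it with engine_use_ray=True exercises the Ray code path that the old synchronous get_tokenizer() could not serve.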
@@ -5,6 +5,8 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 
+from transformers import PreTrainedTokenizer
+
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
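For context, a minimal usage sketch (not part of the diff; the model name and engine arguments are placeholders): after this change, get_tokenizer() is a coroutine and must be awaited whether the engine runs in-process or as a Ray actor.

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def main() -> None:
        # Placeholder model id, not taken from the diff.
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        tokenizer = await engine.get_tokenizer()  # coroutine after this commit
        print(tokenizer.encode("Hello, world"))

    asyncio.run(main())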
@@ -7,6 +7,8 @@ import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
+from transformers import PreTrainedTokenizer
+
 import vllm
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ class LLMEngine:
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):
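Design note with a small sketch (assumed, not from the diff): self.tokenizer on LLMEngine is a TokenizerGroup, and get_lora_tokenizer() with no request falls through to the base tokenizer (see the TokenizerGroup hunks below), so the new method returns the same object the old self.engine.tokenizer.tokenizer access did, while also being callable when LLMEngine runs as a Ray actor.

    # 'llm_engine' is a hypothetical LLMEngine instance.
    base = llm_engine.get_tokenizer()
    # With no LoRA request, TokenizerGroup returns its base tokenizer:
    assert base is llm_engine.tokenizer.tokenizer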
@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
         lora_request = self._maybe_get_lora(request)
         guided_decode_logits_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logits_processor:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
         lora_request = self._maybe_get_lora(request)
         guided_decode_logit_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logit_processor is not None:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
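Both serving hunks apply the same fix; a condensed sketch of the call sequence (the surrounding handler scaffolding and the final append are assumed, not shown in this diff):

    # Inside an async OpenAI-serving handler:
    tokenizer = await self.engine.get_tokenizer()  # must be awaited now
    logits_processor = await get_guided_decoding_logits_processor(
        request, tokenizer)
    if logits_processor is not None:
        if sampling_params.logits_processors is None:
            sampling_params.logits_processors = []
        sampling_params.logits_processors.append(logits_processor)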
@@ -120,7 +120,8 @@ class TokenizerGroup:
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ class TokenizerGroup:
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
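A quick sketch of the new default (the TokenizerGroup constructor arguments are assumptions, not shown in this diff):

    group = TokenizerGroup(tokenizer_id="facebook/opt-125m", enable_lora=False,
                           max_num_seqs=8, max_input_length=None)
    assert group.get_lora_tokenizer() is group.tokenizer      # new default form
    assert group.get_lora_tokenizer(None) is group.tokenizer  # old spelling
    # get_lora_tokenizer_async() gains the same default and is awaited instead.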