[BugFix] Fix get tokenizer when using ray (#3301)

Roy, 2024-03-11 10:17:16 +08:00, committed by GitHub
commit 9e8744a545 (parent e4a28e5316)
6 changed files with 23 additions and 7 deletions

View File

@@ -89,3 +89,6 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
+
+    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_tokenizer() is not None

View File

@@ -5,6 +5,8 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 
+from transformers import PreTrainedTokenizer
+
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""

View File

@@ -7,6 +7,8 @@ import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
+from transformers import PreTrainedTokenizer
+
 import vllm
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ class LLMEngine:
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):
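
For reference, a minimal sketch of how the two accessors added above divide the work: get_tokenizer() returns the tokenizer group's base tokenizer, while get_tokenizer_for_seq() resolves a per-sequence LoRA tokenizer when the sequence carries a LoRA request. The stub classes below are illustrative only, not the real vLLM TokenizerGroup or Sequence.

from typing import Optional

class _Sequence:
    """Stub for a sequence; only the lora_request field is used here."""

    def __init__(self, lora_request: Optional[str] = None):
        self.lora_request = lora_request

class _TokenizerGroupStub:
    """Stub tokenizer group: no LoRA request means the base tokenizer."""

    def __init__(self, base_tokenizer: str):
        self.base_tokenizer = base_tokenizer

    def get_lora_tokenizer(self, lora_request: Optional[str] = None) -> str:
        return f"lora:{lora_request}" if lora_request else self.base_tokenizer

class _EngineSketch:
    def __init__(self, tokenizer_group: _TokenizerGroupStub):
        self.tokenizer = tokenizer_group

    def get_tokenizer(self) -> str:
        # Mirrors the new LLMEngine.get_tokenizer(): no LoRA request passed.
        return self.tokenizer.get_lora_tokenizer()

    def get_tokenizer_for_seq(self, sequence: _Sequence) -> str:
        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

engine = _EngineSketch(_TokenizerGroupStub("base-tokenizer"))
assert engine.get_tokenizer() == "base-tokenizer"
assert engine.get_tokenizer_for_seq(_Sequence("adapter-1")) == "lora:adapter-1"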

View File

@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []

View File

@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
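
Both OpenAI serving call sites (chat above and completion here) gain an await because get_tokenizer is now a coroutine function. Without the await, the guided-decoding helper would receive a coroutine object rather than a tokenizer. A tiny self-contained sketch of the difference, with an illustrative stub engine (not vLLM code):

import asyncio

class _EngineStub:
    async def get_tokenizer(self) -> str:
        return "tokenizer"

async def main() -> None:
    engine = _EngineStub()

    # Missing await: a coroutine object, not the tokenizer.
    coro = engine.get_tokenizer()
    print(type(coro))  # <class 'coroutine'>
    coro.close()       # avoid a "never awaited" warning

    # With await: the actual value, as in the updated call sites.
    print(await engine.get_tokenizer())  # tokenizer

asyncio.run(main())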

View File

@@ -120,7 +120,8 @@ class TokenizerGroup:
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ class TokenizerGroup:
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
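
Finally, the two TokenizerGroup hunks give lora_request a default of None, which is what lets the new LLMEngine.get_tokenizer() call get_lora_tokenizer() with no arguments and receive the base tokenizer. A minimal sketch of that fallback path, with stub names; the real class also loads and caches per-adapter tokenizers, which is omitted here.

from typing import Dict, Optional

class _LoRARequestStub:
    def __init__(self, lora_int_id: int):
        self.lora_int_id = lora_int_id

class _TokenizerGroupSketch:
    def __init__(self, tokenizer: str, enable_lora: bool = False):
        self.tokenizer = tokenizer
        self.enable_lora = enable_lora
        self.lora_tokenizers: Dict[int, str] = {}

    def get_lora_tokenizer(
            self,
            lora_request: Optional[_LoRARequestStub] = None) -> str:
        # No request (the new default) or LoRA disabled: base tokenizer.
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        # Otherwise use the adapter-specific tokenizer if one is cached,
        # falling back to the base tokenizer.
        return self.lora_tokenizers.get(lora_request.lora_int_id, self.tokenizer)

group = _TokenizerGroupSketch(tokenizer="base-tokenizer")
assert group.get_lora_tokenizer() == "base-tokenizer"
assert group.get_lora_tokenizer(_LoRARequestStub(1)) == "base-tokenizer"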