mirror of https://github.com/vllm-project/vllm
[BugFix] Fix get tokenizer when using ray (#3301)
commit 9e8744a545
parent e4a28e5316

@@ -89,3 +89,6 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
+
+    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_tokenizer() is not None
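A side note on the new assertion: after this commit AsyncLLMEngine.get_tokenizer is a coroutine function (see the hunks below), so a bare engine.get_tokenizer() evaluates to a coroutine object, which is non-None whether or not it is ever awaited; callers that need the tokenizer itself must await the call. A generic illustration of that distinction, not vllm code:

    import asyncio

    class Engine:                          # minimal stand-in for the async engine
        async def get_tokenizer(self):
            return "tokenizer"

    async def demo():
        engine = Engine()
        coro = engine.get_tokenizer()      # coroutine object, always non-None
        assert coro is not None
        tokenizer = await coro             # the actual tokenizer
        return tokenizer

    print(asyncio.run(demo()))             # -> tokenizer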
@@ -5,6 +5,8 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 
+from transformers import PreTrainedTokenizer
+
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ class AsyncLLMEngine:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
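The branch distinguishes two deployments: with engine_use_ray the inner LLMEngine is a ray actor, so its methods are invoked through .remote() and the resulting ObjectRef is awaited (ray ObjectRefs are awaitable inside an asyncio event loop); otherwise the engine is a local object and the new LLMEngine.get_tokenizer (two hunks below) is called directly. A generic sketch of the ray half of that pattern, not vllm's actual classes:

    import asyncio
    import ray

    @ray.remote
    class Actor:                       # stands in for the ray-wrapped LLMEngine
        def get_tokenizer(self):
            return "tokenizer"

    async def fetch(actor):
        # .remote() returns an ObjectRef, which can be awaited directly.
        return await actor.get_tokenizer.remote()

    ray.init()
    actor = Actor.remote()
    print(asyncio.run(fetch(actor)))   # -> tokenizer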
@@ -7,6 +7,8 @@ import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
+from transformers import PreTrainedTokenizer
+
 import vllm
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ class LLMEngine:
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):
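This new LLMEngine.get_tokenizer is the method the ray branch above targets through .remote(); without it the actor call would fail with an AttributeError. It simply asks the TokenizerGroup for the base (non-LoRA) tokenizer, which the default argument added in the last two hunks makes legal. A simplified stand-in showing the delegation, not the real classes:

    class FakeTokenizerGroup:              # stands in for TokenizerGroup
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer
            self.enable_lora = False

        def get_lora_tokenizer(self, lora_request=None):
            if not lora_request or not self.enable_lora:
                return self.tokenizer      # base-tokenizer fallback
            raise NotImplementedError      # per-LoRA path elided in this sketch

    class FakeEngine:                      # stands in for LLMEngine
        def __init__(self):
            self.tokenizer = FakeTokenizerGroup("base-tokenizer")

        def get_tokenizer(self):
            # No lora_request: returns the base tokenizer, as in the hunk above.
            return self.tokenizer.get_lora_tokenizer()

    print(FakeEngine().get_tokenizer())    # -> base-tokenizer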
@@ -65,7 +65,7 @@ class OpenAIServingChat(OpenAIServing):
         lora_request = self._maybe_get_lora(request)
         guided_decode_logits_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logits_processor:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
         lora_request = self._maybe_get_lora(request)
         guided_decode_logit_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logit_processor is not None:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
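Both serving endpoints make the same one-token change: since get_tokenizer is now a coroutine, the call must be awaited before its result is handed to the guided-decoding helper; otherwise the helper would receive a coroutine object instead of a tokenizer. A condensed, self-contained sketch of the fixed call shape, with stand-ins rather than the real helper and engine:

    import asyncio

    async def get_guided_processor(request, tokenizer):    # stand-in helper
        assert not asyncio.iscoroutine(tokenizer), "want a tokenizer, not a coroutine"
        return f"processor({request}, {tokenizer})"

    class FakeEngine:
        async def get_tokenizer(self):
            return "tok"

    async def handle(request, engine):
        # Mirrors the fixed line: await the engine call before passing it on.
        return await get_guided_processor(request, await engine.get_tokenizer())

    print(asyncio.run(handle("req", FakeEngine())))   # -> processor(req, tok)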
@@ -120,7 +120,8 @@ class TokenizerGroup:
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ class TokenizerGroup:
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
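With the default value in place, both the sync and async accessors can be called with no argument and fall back to the shared base tokenizer; only calls that actually carry a LoRARequest hit the per-adapter cache. A self-contained sketch of that fallback logic with simplified types, not the vllm classes:

    import asyncio
    from typing import Optional

    class Group:                           # simplified TokenizerGroup
        def __init__(self, tokenizer, enable_lora=False):
            self.tokenizer = tokenizer
            self.enable_lora = enable_lora
            self.lora_tokenizers = {}      # lora_int_id -> per-adapter tokenizer

        def get_lora_tokenizer(self, lora_request: Optional[object] = None):
            if not lora_request or not self.enable_lora:
                return self.tokenizer      # the new default path
            return self.lora_tokenizers[lora_request.lora_int_id]

        async def get_lora_tokenizer_async(
                self, lora_request: Optional[object] = None):
            return self.get_lora_tokenizer(lora_request)

    group = Group("base")
    assert group.get_lora_tokenizer() == "base"
    assert asyncio.run(group.get_lora_tokenizer_async()) == "base"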