mirror of https://github.com/vllm-project/vllm
[Bugfix][Frontend] Guard against bad token ids (#9634)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
0ad216f575
commit
67bdf8e523
|
@ -4,6 +4,12 @@ from vllm import LLM
|
|||
|
||||
|
||||
def test_empty_prompt():
|
||||
llm = LLM(model="gpt2")
|
||||
llm = LLM(model="gpt2", enforce_eager=True)
|
||||
with pytest.raises(ValueError, match='Prompt cannot be empty'):
|
||||
llm.generate([""])
|
||||
|
||||
|
||||
def test_out_of_vocab_token():
|
||||
llm = LLM(model="gpt2", enforce_eager=True)
|
||||
with pytest.raises(ValueError, match='out of vocabulary'):
|
||||
llm.generate({"prompt_token_ids": [999999]})
|
||||
|
|
|
@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
|
|||
@pytest.mark.asyncio
|
||||
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 32000, 32001, 32002],
|
||||
echo=True,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
# Added tokens should not appear in tokenized prompt
|
||||
assert "vllm" not in completion.choices[0].text
|
||||
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
|
||||
# Added tokens should be rejected by the base model
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 32000, 32001, 32002],
|
||||
echo=True,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -20,3 +20,18 @@ async def test_empty_prompt():
|
|||
prompt="",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_out_of_vocab_token_ids():
|
||||
model_name = "gpt2"
|
||||
server_args = ["--enforce-eager"]
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
with pytest.raises(openai.BadRequestError,
|
||||
match=re.compile('.*out of vocabulary.*')):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=[999999],
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
|
|
@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
|
|||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
async def get_tokenizer_async(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> AnyTokenizer:
|
||||
return await (
|
||||
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
|
||||
|
||||
@overload # DEPRECATED
|
||||
async def add_request_async(
|
||||
self,
|
||||
|
@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
|
|||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.tokenizer is not None:
|
||||
tokenizer = await self.get_tokenizer_async(lora_request)
|
||||
self._validate_token_prompt(prompt, tokenizer=tokenizer)
|
||||
|
||||
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
|
@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||
# implementation in the LLMEngine
|
||||
params = await build_guided_decoding_logits_processor_async(
|
||||
sampling_params=params,
|
||||
tokenizer=self.get_tokenizer(lora_request),
|
||||
tokenizer=await self.get_tokenizer_async(lora_request),
|
||||
default_guided_backend=self.decoding_config.
|
||||
guided_decoding_backend)
|
||||
|
||||
|
@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
|
|||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
return await (self.engine.get_tokenizer_group().
|
||||
get_lora_tokenizer_async(lora_request))
|
||||
return await self.engine.get_tokenizer_async(lora_request)
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
|
|||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
|
@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
|
|||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderInputs, InputRegistry, PromptType)
|
||||
EncoderDecoderInputs, InputRegistry, PromptType,
|
||||
TokensPrompt)
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
|
@ -667,7 +668,7 @@ class LLMEngine:
|
|||
)
|
||||
return None
|
||||
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
|
@ -829,6 +830,11 @@ class LLMEngine:
|
|||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self._validate_token_prompt(
|
||||
prompt,
|
||||
tokenizer=self.get_tokenizer(lora_request=lora_request))
|
||||
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
|
@ -855,6 +861,31 @@ class LLMEngine:
|
|||
priority=priority,
|
||||
)
|
||||
|
||||
def _validate_token_prompt(self, prompt: PromptType,
|
||||
tokenizer: AnyTokenizer):
|
||||
# Guard against out-of-vocab tokens.
|
||||
# For some tokenizers, tokenizer.decode will happily return empty text
|
||||
# for token ids that are out of vocab, and we don't detect token ids
|
||||
# that are greater than the max token id before running the model.
|
||||
# However, these token ids will later crash a cuda kernel at runtime
|
||||
# with an index out of bounds error. This will crash the entire engine.
|
||||
# This needs to happen before multimodal input pre-processing, which
|
||||
# may add dummy <image> tokens that aren't part of the tokenizer's
|
||||
# vocabulary.
|
||||
if self._is_token_prompt(prompt):
|
||||
prompt_ids = prompt["prompt_token_ids"]
|
||||
if len(prompt_ids) == 0:
|
||||
# Empty prompt check is handled later
|
||||
return
|
||||
max_input_id = max(prompt_ids)
|
||||
if max_input_id > tokenizer.max_token_id:
|
||||
raise ValueError(
|
||||
"Token id {} is out of vocabulary".format(max_input_id))
|
||||
|
||||
@staticmethod
|
||||
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
|
@ -1942,7 +1973,8 @@ class LLMEngine:
|
|||
return self.input_preprocessor.is_encoder_decoder_model()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderInputs]):
|
||||
EncoderDecoderInputs],
|
||||
lora_request: Optional[LoRARequest]):
|
||||
if self.model_config.is_multimodal_model:
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
|
|
|
@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
|||
tokenizer.all_special_tokens_extended)
|
||||
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
|
||||
tokenizer_len = len(tokenizer)
|
||||
max_token_id = max(tokenizer.get_vocab().values())
|
||||
|
||||
class CachedTokenizer(tokenizer.__class__): # type: ignore
|
||||
|
||||
|
@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
|||
def all_special_tokens_extended(self):
|
||||
return tokenizer_all_special_tokens_extended
|
||||
|
||||
@property
|
||||
def max_token_id(self):
|
||||
return max_token_id
|
||||
|
||||
def __len__(self):
|
||||
return tokenizer_len
|
||||
|
||||
|
|
|
@ -85,6 +85,7 @@ class MistralTokenizer:
|
|||
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
|
||||
|
||||
self.tokenizer = tokenizer_
|
||||
self._max_token_id = max(self._vocab.values())
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
|
@ -158,6 +159,10 @@ class MistralTokenizer:
|
|||
def vocab_size(self) -> int:
|
||||
return len(self._vocab)
|
||||
|
||||
@property
|
||||
def max_token_id(self) -> int:
|
||||
return self._max_token_id
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.vocab_size
|
||||
|
||||
|
|
Loading…
Reference in New Issue