[Bugfix][Frontend] Guard against bad token ids (#9634)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-10-29 16:13:20 -05:00 committed by GitHub
parent 0ad216f575
commit 67bdf8e523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 89 additions and 17 deletions

View File

@ -4,6 +4,12 @@ from vllm import LLM
def test_empty_prompt():
llm = LLM(model="gpt2")
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""])
def test_out_of_vocab_token():
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'):
llm.generate({"prompt_token_ids": [999999]})

View File

@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should not appear in tokenized prompt
assert "vllm" not in completion.choices[0].text
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
# Added tokens should be rejected by the base model
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
@pytest.mark.asyncio

View File

@ -20,3 +20,18 @@ async def test_empty_prompt():
prompt="",
max_tokens=5,
temperature=0.0)
@pytest.mark.asyncio
async def test_out_of_vocab_token_ids():
model_name = "gpt2"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(openai.BadRequestError,
match=re.compile('.*out of vocabulary.*')):
await client.completions.create(model=model_name,
prompt=[999999],
max_tokens=5,
temperature=0.0)

View File

@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
@overload # DEPRECATED
async def add_request_async(
self,
@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
request_id=request_id,
@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))
return await self.engine.get_tokenizer_async(lora_request)
def start_background_loop(self) -> None:
"""Start the background loop."""

View File

@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar
from typing_extensions import TypeIs, TypeVar
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
EncoderDecoderInputs, InputRegistry, PromptType,
TokensPrompt)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@ -667,7 +668,7 @@ class LLMEngine:
)
return None
self._validate_model_inputs(processed_inputs)
self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
@ -829,6 +830,11 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
@ -855,6 +861,31 @@ class LLMEngine:
priority=priority,
)
def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if self._is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
@staticmethod
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def _create_sequence_group_with_sampling(
self,
request_id: str,
@ -1942,7 +1973,8 @@ class LLMEngine:
return self.input_preprocessor.is_encoder_decoder_model()
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
EncoderDecoderInputs],
lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length

View File

@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer.get_vocab().values())
class CachedTokenizer(tokenizer.__class__): # type: ignore
@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def all_special_tokens_extended(self):
return tokenizer_all_special_tokens_extended
@property
def max_token_id(self):
return max_token_id
def __len__(self):
return tokenizer_len

View File

@ -85,6 +85,7 @@ class MistralTokenizer:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
self.tokenizer = tokenizer_
self._max_token_id = max(self._vocab.values())
@classmethod
def from_pretrained(cls,
@ -158,6 +159,10 @@ class MistralTokenizer:
def vocab_size(self) -> int:
return len(self._vocab)
@property
def max_token_id(self) -> int:
return self._max_token_id
def __len__(self) -> int:
return self.vocab_size