mirror of https://github.com/vllm-project/vllm
[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
parent: 3257d449fa
commit: 9323a3153b
@@ -178,6 +178,7 @@ autodoc_mock_imports = [
     "tensorizer",
     "pynvml",
     "outlines",
+    "xgrammar",
     "librosa",
     "soundfile",
     "gguf",
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines >= 0.0.43, < 0.1
+xgrammar
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
@@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
         sampling_params=sampling_params,
         use_tqdm=True,
         guided_options_request=dict(guided_regex=sample_regex))
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_json_object(llm):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        guided_decoding=GuidedDecodingParams(json_object=True))
+
+    outputs = llm.generate(
+        prompts=("Generate a JSON object describing a person with name "
+                 "and age for John Smith who is 31 years old."),
+        sampling_params=sampling_params,
+        use_tqdm=True)
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
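Not part of the diff: for readers who want to try the new json_object mode outside of pytest, here is a minimal offline sketch using the same LLM / SamplingParams / GuidedDecodingParams entry points the test above relies on. The model name and prompt are placeholders, not taken from this change.

import json

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model, any chat model works

sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
    prompts="Generate a JSON object describing John Smith, who is 31 years old.",
    sampling_params=sampling_params)

# The backend constrains decoding to valid JSON, so this parse should succeed
# unless the output is cut off by max_tokens.
print(json.loads(outputs[0].outputs[0].text))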
@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.asyncio
-@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("backend",
+                         ["outlines", "lm-format-enforcer", "xgrammar"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                  sample_json_schema):
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
@@ -1789,15 +1789,15 @@ class PoolerConfig:

     step_tag_id: Optional[int] = None
     """
     If set, only the score corresponding to the ``step_tag_id`` in the
     generated sentence should be returned. Otherwise, the scores for all tokens
     are returned.
     """

     returned_token_ids: Optional[List[int]] = None
     """
     A list of indices for the vocabulary dimensions to be extracted,
     such as the token IDs of ``good_token`` and ``bad_token`` in the
     ``math-shepherd-mistral-7b-prm`` model.
     """
@@ -2031,11 +2031,12 @@ def get_served_model_name(model: str,
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine"""

-    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
-    guided_decoding_backend: str = 'outlines'
+    # Which guided decoding algo to use.
+    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
+    guided_decoding_backend: str = 'xgrammar'

     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
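With the engine-wide default flipped to 'xgrammar', individual requests can still pin a backend through GuidedDecodingParams.backend, which is the field the engine reads as guided_decoding.backend or default_guided_backend. A small sketch of both levels follows; it is not part of the diff and the guidance values are illustrative.

from vllm.config import DecodingConfig
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Engine-wide default: now 'xgrammar'; __post_init__ rejects unknown backends.
assert DecodingConfig().guided_decoding_backend == "xgrammar"
DecodingConfig(guided_decoding_backend="outlines")   # still accepted
# DecodingConfig(guided_decoding_backend="typo")     # would raise ValueError

# Per-request override: backend on GuidedDecodingParams wins over the default.
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json_object=True, backend="outlines"))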
@@ -2222,7 +2223,7 @@ class CompilationConfig(BaseModel):
     from Python, functions can also be passed directly via Python object
     constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
     - custom inductor passes: see PassConfig for more details

     Why we have different sizes for cudagraph and inductor:
     - cudagraph: a cudagraph captured for a specific size can only be used
       for the same size. We need to capture all the sizes we want to use.
@@ -168,7 +168,7 @@ class EngineArgs:
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: Optional[bool] = None

-    guided_decoding_backend: str = 'outlines'
+    guided_decoding_backend: str = 'xgrammar'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
@@ -364,11 +364,12 @@ class EngineArgs:
         parser.add_argument(
             '--guided-decoding-backend',
             type=str,
-            default='outlines',
-            choices=['outlines', 'lm-format-enforcer'],
+            default='xgrammar',
+            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines and '
+            'https://github.com/outlines-dev/outlines, '
+            'https://github.com/mlc-ai/xgrammar, and '
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
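The help text notes the backend can also be overridden per request. A hedged sketch of how that looks against the OpenAI-compatible server follows; the vllm serve invocation in the comment and the guided_json / guided_decoding_backend extra-body fields are assumptions about the server API, not part of this diff, and the model name is a placeholder.

# Assumed launch command:
#   vllm serve Qwen/Qwen2.5-1.5B-Instruct --guided-decoding-backend xgrammar
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
    messages=[{"role": "user", "content": "Return John Smith as JSON."}],
    extra_body={
        "guided_json": {"type": "object"},      # schema-constrained output
        "guided_decoding_backend": "outlines",  # per-request backend override
    })
print(resp.choices[0].message.content)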
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import time
 import weakref
 from functools import partial
@@ -507,7 +508,8 @@ class _AsyncLLMEngine(LLMEngine):
             sampling_params=params,
             tokenizer=await self.get_tokenizer_async(lora_request),
             default_guided_backend=self.decoding_config.
-            guided_decoding_backend)
+            guided_decoding_backend,
+            model_config=self.model_config)

         self._add_processed_request(
             request_id=request_id,
@@ -528,22 +530,30 @@ class _AsyncLLMEngine(LLMEngine):


 async def build_guided_decoding_logits_processor_async(
         sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str) -> SamplingParams:
+        default_guided_backend: str,
+        model_config: ModelConfig) -> SamplingParams:
     """Constructs logits processors based on the guided_decoding,
     logits_bias, and allowed_token_ids fields in sampling_params. Deletes
     those fields and adds the constructed logits processors to the
     logits_processors field. Modifies sampling params in-place and returns
     the modified sampling params."""
-    if (guided_decoding := sampling_params.guided_decoding) is None:
+    if sampling_params.guided_decoding is None:
         return sampling_params

+    # Defensively copy sampling params since guided decoding logits
+    # processors can have different state for each request
+    sampling_params = copy.copy(sampling_params)
+    guided_decoding = sampling_params.guided_decoding
+
     logger.debug("Building guided decoding logits processor. "
                  "Params: %s", guided_decoding)

     guided_decoding.backend = guided_decoding.backend or default_guided_backend

     processor = await get_guided_decoding_logits_processor(
-        guided_params=guided_decoding, tokenizer=tokenizer)
+        guided_params=guided_decoding,
+        tokenizer=tokenizer,
+        model_config=model_config)

     if processor:
         if sampling_params.logits_processors is None:
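The defensive copy.copy above exists because a guided-decoding logits processor is stateful: it tracks the tokens already accepted for its own request, so attaching it to a SamplingParams object that callers reuse across requests would let requests share and corrupt that state. A toy illustration of the pattern follows; Params and StatefulProcessor are hypothetical stand-ins, not vLLM classes.

import copy
from dataclasses import dataclass, field

@dataclass
class Params:
    """Hypothetical stand-in for SamplingParams."""
    logits_processors: list = field(default_factory=list)

@dataclass
class StatefulProcessor:
    """Hypothetical stand-in for a guided-decoding processor tracking accepted tokens."""
    seen: list = field(default_factory=list)

def handle_request(shared_params: Params) -> Params:
    # Mirror the new code path: shallow-copy the shared params, then attach a
    # processor built for this request only.
    params = copy.copy(shared_params)
    params.logits_processors = list(params.logits_processors)
    params.logits_processors.append(StatefulProcessor())
    return params

base = Params()
r1, r2 = handle_request(base), handle_request(base)
r1.logits_processors[0].seen.append(42)
print(r2.logits_processors[0].seen)  # [] -- the two requests no longer share state
print(base.logits_processors)        # [] -- the caller's params object is untouched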
@@ -1,3 +1,4 @@
+import copy
 import time
 from collections import Counter as collectionsCounter
 from collections import deque
@@ -1024,9 +1025,9 @@
         This function updates num_computed_tokens for prompt sequences
         when Multi-Step is enabled.

             seq_group: SequenceGroup to update the num_computed_tokens for.
             seq_group_meta: Metadata of the given SequenceGroup.
             is_first_step_output: Optional[bool] -
                 When available, is_first_step_output indicates if the appended
                 output token is the output of the first-step in multi-step.
                 A value of None indicates that outputs from all steps in
@@ -2036,7 +2037,11 @@

         logits_processors = []

-        if (guided_decoding := sampling_params.guided_decoding) is not None:
+        if sampling_params.guided_decoding is not None:
+            # Defensively copy sampling params since guided decoding logits
+            # processors can have different state for each request
+            sampling_params = copy.copy(sampling_params)
+            guided_decoding = sampling_params.guided_decoding
+
             logger.debug(
                 "Building guided decoding logits processor in "
|
||||||
self.decoding_config.guided_decoding_backend
|
self.decoding_config.guided_decoding_backend
|
||||||
|
|
||||||
processor = get_local_guided_decoding_logits_processor(
|
processor = get_local_guided_decoding_logits_processor(
|
||||||
guided_params=guided_decoding, tokenizer=tokenizer)
|
guided_params=guided_decoding,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
model_config=self.model_config)
|
||||||
if processor:
|
if processor:
|
||||||
logits_processors.append(processor)
|
logits_processors.append(processor)
|
||||||
|
|
||||||
|
|
|
@@ -474,8 +474,8 @@ class MQLLMEngineClient(EngineClient):
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
            priority: Priority of the request (lower means earlier handling).
                Any priority other than 0 will lead to an error if the
                scheduling policy is not "priority".
        """
        if inputs is not None:
@@ -589,6 +589,7 @@ class MQLLMEngineClient(EngineClient):
             default_guided_backend=(self.decoding_config.guided_decoding_backend
                                     if self.decoding_config
                                     else DecodingConfig.guided_decoding_backend),
+            model_config=self.model_config
         )

         # 1) Create output queue for this requests.
@@ -1,14 +1,54 @@
-from typing import Optional
+from __future__ import annotations

-from vllm.logits_process import LogitsProcessor
-from vllm.sampling_params import GuidedDecodingParams
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+    from vllm.config import ModelConfig
+    from vllm.logits_process import LogitsProcessor
+    from vllm.sampling_params import GuidedDecodingParams
+
+logger = init_logger(__name__)
+
+
+def maybe_backend_fallback(
+        guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
+    # lm-format-enforce doesn't support grammar, fallback to xgrammar
+    if (guided_params.backend == "lm-format-enforcer"
+            and guided_params.grammar is not None):
+        logger.warning(
+            "lm-format-enforcer does not support grammar guided decoding. "
+            "Falling back to use xgrammar instead.")
+        guided_params.backend = "xgrammar"
+
+    if guided_params.backend == "xgrammar":
+        # xgrammar doesn't support regex or choice, fallback to outlines
+        if guided_params.regex is not None or guided_params.choice is not None:
+            logger.warning(
+                "xgrammar only supports json or grammar guided decoding. "
+                "Falling back to use outlines instead.")
+            guided_params.backend = "outlines"
+
+        # xgrammar only supports EBNF grammars and uses the GBNF format
+        # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
+        elif (guided_params.grammar is not None
+              and "::=" not in guided_params.grammar):
+            logger.warning("xgrammar only supports EBNF grammars. "
+                           "Falling back to use outlines instead.")
+            guided_params.backend = "outlines"
+
+    return guided_params
+
+
 async def get_guided_decoding_logits_processor(
-    guided_params: GuidedDecodingParams,
-    tokenizer) -> Optional[LogitsProcessor]:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
+    guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines' or guided_params.grammar:
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
             get_outlines_guided_decoding_logits_processor)
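The fallback rules above can be exercised directly; a short sketch using only the GuidedDecodingParams fields this module already reads (regex, grammar, backend):

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# xgrammar cannot handle regex guidance, so the request falls back to
# outlines (with a warning logged).
p = maybe_backend_fallback(
    GuidedDecodingParams(regex=r"\d{3}-\d{4}", backend="xgrammar"))
print(p.backend)  # -> "outlines"

# lm-format-enforcer cannot handle grammars, so it falls back to xgrammar,
# which keeps the request because the grammar is GBNF-style EBNF ("::=").
p = maybe_backend_fallback(
    GuidedDecodingParams(grammar='root ::= "yes" | "no"',
                         backend="lm-format-enforcer"))
print(p.backend)  # -> "xgrammar"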
@@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'xgrammar':
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer, model_config)

     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")


 def get_local_guided_decoding_logits_processor(
-    guided_params: GuidedDecodingParams,
-    tokenizer) -> Optional[LogitsProcessor]:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
+    guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines' or guided_params.grammar:
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
             get_local_outlines_guided_decoding_logits_processor)
@@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'xgrammar':
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer, model_config)

     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
@@ -0,0 +1,251 @@
+# noqa: UP007
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, NamedTuple
+
+import torch
+from transformers import PreTrainedTokenizerFast
+
+try:
+    import xgrammar as xgr
+    from xgrammar.base import _core as xgr_core
+except ImportError:
+    pass
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+    from vllm.config import ModelConfig
+    from vllm.sampling_params import GuidedDecodingParams
+
+
+# TODO: passing batch size to max threads here
+def get_local_xgrammar_guided_decoding_logits_processor(
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig,
+        max_threads: int = 8):
+    config = GrammarConfig.from_guided_params(guided_params=guided_params,
+                                              model_config=model_config,
+                                              tokenizer=tokenizer,
+                                              max_threads=max_threads)
+    return XGrammarLogitsProcessor(config)
+
+
+class TokenizerData(NamedTuple):
+    """Immutable container for cached tokenizer data."""
+    encoded_vocab: list[str]
+    stop_token_ids: list[int] | None
+    backend_str: str
+
+
+class TokenizerDataCache:
+    """Cache manager for tokenizer data to avoid repeated processing."""
+    _cache: dict[int, TokenizerData] = {}
+
+    @classmethod
+    def get_tokenizer_data(cls,
+                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
+        tokenizer_hash = hash(tokenizer)
+
+        if tokenizer_hash not in cls._cache:
+            # Vendored from xgrammar logic since we cannot pickle the tokenizer
+            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
+            try:
+                encoded_vocab = [
+                    token for token, _ in sorted(tokenizer.get_vocab().items(),
+                                                 key=lambda x: x[1])
+                ]
+            except AttributeError as e:
+                raise ValueError(
+                    f"Cannot get the vocabulary of the tokenizer "
+                    f"{type(tokenizer)}. The tokenizer should have a "
+                    "get_vocab method.") from e
+
+            stop_token_ids = None
+            backend_str = xgr.VocabType.RAW
+            if isinstance(tokenizer, PreTrainedTokenizerFast):
+                backend_str = tokenizer.backend_tokenizer.to_str()
+                if stop_token_ids is None and hasattr(
+                        tokenizer,
+                        "eos_token_id") and tokenizer.eos_token_id is not None:
+                    stop_token_ids = [tokenizer.eos_token_id]
+
+            cls._cache[tokenizer_hash] = TokenizerData(
+                encoded_vocab=encoded_vocab,
+                stop_token_ids=stop_token_ids,
+                backend_str=backend_str)
+
+        return cls._cache[tokenizer_hash]
+
+
+class GrammarCompilerCache:
+    """
+    Cache for GrammarCompiler instances based on tokenizer.
+
+    This cache reduces the overhead of creating new compiler instances when
+    using the same tokenizer configuration.
+    """
+    _cache: dict[str, xgr.GrammarCompiler] = {}
+
+    @classmethod
+    def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
+        cache_key = str(config.tokenizer_hash)
+
+        if cache_key not in cls._cache:
+            assert config.encoded_vocab is not None
+            tokenizer_info = xgr.TokenizerInfo._create_from_handle(
+                xgr_core.TokenizerInfo.from_huggingface(
+                    config.encoded_vocab, config.backend_str,
+                    config.vocab_size, config.stop_token_ids))
+            cls._cache[cache_key] = xgr.GrammarCompiler(
+                tokenizer_info, max_threads=config.max_threads)
+
+        return cls._cache[cache_key]
+
+
+@dataclass
+class GrammarConfig:
+    """Serializable configuration for grammar compilation"""
+    tokenizer_hash: int
+    vocab_size: int
+    json_str: str | None = None
+    grammar_str: str | None = None
+    json_object: bool | None = None
+    max_threads: int = 8
+    # Only populated if tokenizer_hash not in cache
+    encoded_vocab: list[str] | None = None
+    stop_token_ids: list[int] | None = None
+    backend_str: str | None = None
+
+    @classmethod
+    def from_guided_params(cls,
+                           guided_params: GuidedDecodingParams,
+                           model_config: ModelConfig,
+                           tokenizer: PreTrainedTokenizer,
+                           max_threads: int = 8) -> GrammarConfig:
+
+        tokenizer_hash = hash(tokenizer)
+        # Only get tokenizer data if not already cached
+        if tokenizer_hash in TokenizerDataCache._cache:
+            encoded_vocab = None
+            stop_token_ids = None
+            backend_str = None
+        else:
+            tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
+            encoded_vocab = tokenizer_data.encoded_vocab
+            stop_token_ids = tokenizer_data.stop_token_ids
+            backend_str = tokenizer_data.backend_str
+
+        if guided_params.json:
+            if not isinstance(guided_params.json, str):
+                json_str = json.dumps(guided_params.json)
+            else:
+                json_str = guided_params.json
+            return cls(json_str=json_str,
+                       vocab_size=model_config.hf_config.vocab_size,
+                       encoded_vocab=encoded_vocab,
+                       stop_token_ids=stop_token_ids,
+                       backend_str=backend_str,
+                       tokenizer_hash=tokenizer_hash,
+                       max_threads=max_threads)
+        elif guided_params.grammar:
+            return cls(grammar_str=guided_params.grammar,
+                       vocab_size=model_config.hf_config.vocab_size,
+                       encoded_vocab=encoded_vocab,
+                       stop_token_ids=stop_token_ids,
+                       backend_str=backend_str,
+                       tokenizer_hash=tokenizer_hash,
+                       max_threads=max_threads)
+        elif guided_params.json_object:
+            return cls(json_object=True,
+                       vocab_size=model_config.hf_config.vocab_size,
+                       encoded_vocab=encoded_vocab,
+                       stop_token_ids=stop_token_ids,
+                       backend_str=backend_str,
+                       tokenizer_hash=tokenizer_hash,
+                       max_threads=max_threads)
+        else:
+            raise ValueError(
+                "Currently only support JSON and EBNF grammar mode for xgrammar"
+            )
+
+
+@dataclass
+class XGrammarLogitsProcessor:
+    """Wrapper class to support pickle protocol"""
+    config: GrammarConfig
+
+    ctx: xgr.CompiledGrammar | None = None
+    token_bitmask: torch.Tensor = None  # type: ignore[assignment]
+    matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
+    batch_size: int = field(default=1)
+    prefilled: bool = field(default=False)
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {'config': self.config}
+
+    def __setstate__(self, state: dict[str, Any]):
+        self.config = state['config']
+
+        self.ctx = None
+        self.matchers = []
+        self.batch_size = 1
+        self.token_bitmask = None  # type: ignore[assignment]
+        self.prefilled = False
+
+    def _ensure_ctx(self):
+        """Lazily initialize the processor in the worker process"""
+        if self.ctx is None:
+            compiler = GrammarCompilerCache.get_compiler(self.config)
+            if self.config.json_str is not None:
+                self.ctx = compiler.compile_json_schema(self.config.json_str)
+            elif self.config.grammar_str is not None:
+                self.ctx = compiler.compile_grammar(self.config.grammar_str)
+            elif self.config.json_object:
+                self.ctx = compiler.compile_builtin_json_grammar()
+            else:
+                raise ValueError(
+                    "Invalid configuration for xgrammar logits processor")
+
+    def __call__(self, input_ids: list[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        if self.ctx is None:
+            self._ensure_ctx()
+
+        if len(self.matchers) == 0:
+            self.matchers = [
+                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+            ]
+            self.token_bitmask = xgr.allocate_token_bitmask(
+                self.batch_size, self.config.vocab_size)
+
+        if not self.prefilled:
+            # Have not sampled a token yet
+            self.prefilled = True
+        else:
+            for i, matcher in enumerate(self.matchers):
+                if not matcher.is_terminated():
+                    sampled_token = input_ids[-1]
+                    assert self.matchers[i].accept_token(sampled_token)
+
+        for i, matcher in enumerate(self.matchers):
+            if not matcher.is_terminated():
+                # @ubospica: ideally, fill_next_token_bitmask should be
+                # parallelized with model decoding
+                # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
+                matcher.fill_next_token_bitmask(self.token_bitmask, i)
+
+        # token_bitmask is a CPU tensor for use with accept_token and
+        # fill_next_token_bitmask so we move it to the device of scores
+        device_type = scores.device.type
+        if device_type != "cuda":
+            scores = scores.to("cpu")
+        xgr.apply_token_bitmask_inplace(scores,
+                                        self.token_bitmask.to(scores.device))
+        if device_type != "cuda":
+            scores = scores.to(device_type)
+
+        return scores
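For readers unfamiliar with xgrammar itself, the sketch below runs the same compile -> matcher -> bitmask -> mask-logits loop the processor above performs once per decoding step, but standalone. It uses xgrammar's public TokenizerInfo.from_huggingface helper rather than the vendored private-handle path, which is an assumption about the public API, and the tokenizer is the same placeholder the tests use.

import torch
import xgrammar as xgr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Public constructor (assumed); the vLLM code above instead rebuilds this from
# cached vocab data so the grammar config stays picklable across processes.
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
compiled = compiler.compile_builtin_json_grammar()   # same call as json_object mode

matcher = xgr.GrammarMatcher(compiled)
bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)  # batch_size=1

logits = torch.randn(tokenizer_info.vocab_size)      # stand-in for one step of model scores
matcher.fill_next_token_bitmask(bitmask, 0)          # mark which tokens are legal next
xgr.apply_token_bitmask_inplace(logits, bitmask)     # illegal tokens are masked out

next_token = int(torch.argmax(logits))               # greedy pick, for the sketch only
assert matcher.accept_token(next_token)              # advance the grammar state
print("terminated:", matcher.is_terminated())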