[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Aaron Pham 2024-12-03 02:17:00 -05:00 committed by GitHub
parent 3257d449fa
commit 9323a3153b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 385 additions and 33 deletions

View File

@ -178,6 +178,7 @@ autodoc_mock_imports = [
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"librosa",
"soundfile",
"gguf",

View File

@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1
xgrammar
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs

View File

@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
@pytest.mark.skip_global_cleanup
def test_guided_json_object(llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
guided_decoding=GuidedDecodingParams(json_object=True))
outputs = llm.generate(
prompts=("Generate a JSON object describing a person with name "
"and age for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)

View File

@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

View File

@ -2031,11 +2031,12 @@ def get_served_model_name(model: str,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""
# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'
# Which guided decoding algo to use.
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"

View File

@ -168,7 +168,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
guided_decoding_backend: str = 'outlines'
guided_decoding_backend: str = 'xgrammar'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
@ -364,11 +364,12 @@ class EngineArgs:
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
default='xgrammar',
choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines and '
'https://github.com/outlines-dev/outlines,'
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')

View File

@ -1,4 +1,5 @@
import asyncio
import copy
import time
import weakref
from functools import partial
@ -507,7 +508,8 @@ class _AsyncLLMEngine(LLMEngine):
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
guided_decoding_backend,
model_config=self.model_config)
self._add_processed_request(
request_id=request_id,
@ -528,22 +530,30 @@ class _AsyncLLMEngine(LLMEngine):
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
default_guided_backend: str,
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
if sampling_params.guided_decoding is None:
return sampling_params
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)
guided_decoding.backend = guided_decoding.backend or default_guided_backend
processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=model_config)
if processor:
if sampling_params.logits_processors is None:

View File

@ -1,3 +1,4 @@
import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
@ -2036,7 +2037,11 @@ class LLMEngine:
logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:
if sampling_params.guided_decoding is not None:
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding
logger.debug(
"Building guided decoding logits processor in "
@ -2047,7 +2052,9 @@ class LLMEngine:
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config)
if processor:
logits_processors.append(processor)

View File

@ -589,6 +589,7 @@ class MQLLMEngineClient(EngineClient):
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config
)
# 1) Create output queue for this requests.

View File

@ -1,14 +1,54 @@
from typing import Optional
from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if (guided_params.backend == "lm-format-enforcer"
and guided_params.grammar is not None):
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"
if guided_params.backend == "xgrammar":
# xgrammar doesn't support regex or choice, fallback to outlines
if guided_params.regex is not None or guided_params.choice is not None:
logger.warning(
"xgrammar only supports json or grammar guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
# xgrammar only supports EBNF grammars and uses the GBNF format
# https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
elif (guided_params.grammar is not None
and "::=" not in guided_params.grammar):
logger.warning("xgrammar only supports EBNF grammars. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"
return guided_params
async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)
raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")

View File

@ -0,0 +1,251 @@
# noqa: UP007
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, NamedTuple
import torch
from transformers import PreTrainedTokenizerFast
try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
except ImportError:
pass
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.sampling_params import GuidedDecodingParams
# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config)
class TokenizerData(NamedTuple):
"""Immutable container for cached tokenizer data."""
encoded_vocab: list[str]
stop_token_ids: list[int] | None
backend_str: str
class TokenizerDataCache:
"""Cache manager for tokenizer data to avoid repeated processing."""
_cache: dict[int, TokenizerData] = {}
@classmethod
def get_tokenizer_data(cls,
tokenizer: PreTrainedTokenizer) -> TokenizerData:
tokenizer_hash = hash(tokenizer)
if tokenizer_hash not in cls._cache:
# Vendored from xgrammar logic since we cannot pickle the tokenizer
# https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
try:
encoded_vocab = [
token for token, _ in sorted(tokenizer.get_vocab().items(),
key=lambda x: x[1])
]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
stop_token_ids = None
backend_str = xgr.VocabType.RAW
if isinstance(tokenizer, PreTrainedTokenizerFast):
backend_str = tokenizer.backend_tokenizer.to_str()
if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str)
return cls._cache[tokenizer_hash]
class GrammarCompilerCache:
"""
Cache for GrammarCompiler instances based on tokenizer.
This cache reduces the overhead of creating new compiler instances when
using the same tokenizer configuration.
"""
_cache: dict[str, xgr.GrammarCompiler] = {}
@classmethod
def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
cache_key = str(config.tokenizer_hash)
if cache_key not in cls._cache:
assert config.encoded_vocab is not None
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config.encoded_vocab, config.backend_str,
config.vocab_size, config.stop_token_ids))
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)
return cls._cache[cache_key]
@dataclass
class GrammarConfig:
"""Serializable configuration for grammar compilation"""
tokenizer_hash: int
vocab_size: int
json_str: str | None = None
grammar_str: str | None = None
json_object: bool | None = None
max_threads: int = 8
# Only populated if tokenizer_hash not in cache
encoded_vocab: list[str] | None = None
stop_token_ids: list[int] | None = None
backend_str: str | None = None
@classmethod
def from_guided_params(cls,
guided_params: GuidedDecodingParams,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
max_threads: int = 8) -> GrammarConfig:
tokenizer_hash = hash(tokenizer)
# Only get tokenizer data if not already cached
if tokenizer_hash in TokenizerDataCache._cache:
encoded_vocab = None
stop_token_ids = None
backend_str = None
else:
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str
if guided_params.json:
if not isinstance(guided_params.json, str):
json_str = json.dumps(guided_params.json)
else:
json_str = guided_params.json
return cls(json_str=json_str,
vocab_size=model_config.hf_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
elif guided_params.grammar:
return cls(grammar_str=guided_params.grammar,
vocab_size=model_config.hf_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
elif guided_params.json_object:
return cls(json_object=True,
vocab_size=model_config.hf_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
)
@dataclass
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
ctx: xgr.CompiledGrammar | None = None
token_bitmask: torch.Tensor = None # type: ignore[assignment]
matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
batch_size: int = field(default=1)
prefilled: bool = field(default=False)
def __getstate__(self) -> dict[str, Any]:
return {'config': self.config}
def __setstate__(self, state: dict[str, Any]):
self.config = state['config']
self.ctx = None
self.matchers = []
self.batch_size = 1
self.token_bitmask = None # type: ignore[assignment]
self.prefilled = False
def _ensure_ctx(self):
"""Lazily initialize the processor in the worker process"""
if self.ctx is None:
compiler = GrammarCompilerCache.get_compiler(self.config)
if self.config.json_str is not None:
self.ctx = compiler.compile_json_schema(self.config.json_str)
elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar()
else:
raise ValueError(
"Invalid configuration for xgrammar logits processor")
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
if self.ctx is None:
self._ensure_ctx()
if len(self.matchers) == 0:
self.matchers = [
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.config.vocab_size)
if not self.prefilled:
# Have not sampled a token yet
self.prefilled = True
else:
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
sampled_token = input_ids[-1]
assert self.matchers[i].accept_token(sampled_token)
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
# @ubospica: ideally, fill_next_token_bitmask should be
# parallelized with model decoding
# See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
matcher.fill_next_token_bitmask(self.token_bitmask, i)
# token_bitmask is a CPU tensor for use with accept_token and
# fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type
if device_type != "cuda":
scores = scores.to("cpu")
xgr.apply_token_bitmask_inplace(scores,
self.token_bitmask.to(scores.device))
if device_type != "cuda":
scores = scores.to(device_type)
return scores