[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Aaron Pham authored 2024-12-03 02:17:00 -05:00, committed by GitHub
parent 3257d449fa
commit 9323a3153b
11 changed files with 385 additions and 33 deletions
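
For context, a minimal usage sketch (not part of this diff) of guided decoding after this change, assuming a vLLM build that includes this patch plus the xgrammar package; the model name is an arbitrary placeholder.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

# No backend given: the engine-wide default is now 'xgrammar'
# (previously 'outlines').
params = SamplingParams(
    temperature=0.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
    prompts="Generate a JSON object for John Smith, who is 31 years old.",
    sampling_params=params)
print(outputs[0].outputs[0].text)

# Per-request override is still available via the backend field; regex is
# also one of the cases xgrammar falls back from (see
# vllm/model_executor/guided_decoding/__init__.py below).
regex_params = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(regex=r"\d{3}-\d{4}",
                                         backend="outlines"))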

docs/source/conf.py

@@ -178,6 +178,7 @@ autodoc_mock_imports = [
     "tensorizer",
     "pynvml",
     "outlines",
+    "xgrammar",
     "librosa",
     "soundfile",
     "gguf",

requirements-common.txt

@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines >= 0.0.43, < 0.1
+xgrammar
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs

tests/entrypoints/llm/test_guided_generate.py

@@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
                      sampling_params=sampling_params,
                      use_tqdm=True,
                      guided_options_request=dict(guided_regex=sample_regex))
+
+
+@pytest.mark.skip_global_cleanup
+def test_guided_json_object(llm):
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        guided_decoding=GuidedDecodingParams(json_object=True))
+    outputs = llm.generate(
+        prompts=("Generate a JSON object describing a person with name "
+                 "and age for John Smith who is 31 years old."),
+        sampling_params=sampling_params,
+        use_tqdm=True)
+
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        print(generated_text)
+        assert generated_text is not None
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)

tests/model_executor/test_guided_processors.py

@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
 @pytest.mark.asyncio
-@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("backend",
+                         ["outlines", "lm-format-enforcer", "xgrammar"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
                                                  sample_json_schema):
     tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')

vllm/config.py

@@ -1789,15 +1789,15 @@ class PoolerConfig:
     step_tag_id: Optional[int] = None
     """
     If set, only the score corresponding to the ``step_tag_id`` in the
     generated sentence should be returned. Otherwise, the scores for all tokens
     are returned.
     """

     returned_token_ids: Optional[List[int]] = None
     """
     A list of indices for the vocabulary dimensions to be extracted,
     such as the token IDs of ``good_token`` and ``bad_token`` in the
     ``math-shepherd-mistral-7b-prm`` model.
     """
@@ -2031,11 +2031,12 @@ def get_served_model_name(model: str,
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine"""

-    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
-    guided_decoding_backend: str = 'outlines'
+    # Which guided decoding algo to use.
+    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
+    guided_decoding_backend: str = 'xgrammar'

     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
@@ -2222,7 +2223,7 @@ class CompilationConfig(BaseModel):
     from Python, functions can also be passed directly via Python object
     constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
     - custom inductor passes: see PassConfig for more details

     Why we have different sizes for cudagraph and inductor:
     - cudagraph: a cudagraph captured for a specific size can only be used
       for the same size. We need to capture all the sizes we want to use.

vllm/engine/arg_utils.py

@@ -168,7 +168,7 @@ class EngineArgs:
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: Optional[bool] = None

-    guided_decoding_backend: str = 'outlines'
+    guided_decoding_backend: str = 'xgrammar'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
@@ -364,11 +364,12 @@ class EngineArgs:
         parser.add_argument(
             '--guided-decoding-backend',
             type=str,
-            default='outlines',
-            choices=['outlines', 'lm-format-enforcer'],
+            default='xgrammar',
+            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/outlines-dev/outlines and '
+            'https://github.com/outlines-dev/outlines, '
+            'https://github.com/mlc-ai/xgrammar, and '
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
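
For context, a sketch (not part of this diff) of the default flip above as seen through EngineArgs; the model name is just a placeholder, and the server-side equivalent of the override is --guided-decoding-backend outlines.

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m")
assert args.guided_decoding_backend == "xgrammar"  # new default

# Opt back in to the previous backend explicitly.
pinned = EngineArgs(model="facebook/opt-125m",
                    guided_decoding_backend="outlines")
assert pinned.guided_decoding_backend == "outlines"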

vllm/engine/async_llm_engine.py

@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import time
 import weakref
 from functools import partial
@@ -507,7 +508,8 @@ class _AsyncLLMEngine(LLMEngine):
                 sampling_params=params,
                 tokenizer=await self.get_tokenizer_async(lora_request),
                 default_guided_backend=self.decoding_config.
-                guided_decoding_backend)
+                guided_decoding_backend,
+                model_config=self.model_config)

         self._add_processed_request(
             request_id=request_id,
@@ -528,22 +530,30 @@ class _AsyncLLMEngine(LLMEngine):

 async def build_guided_decoding_logits_processor_async(
         sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str) -> SamplingParams:
+        default_guided_backend: str,
+        model_config: ModelConfig) -> SamplingParams:
     """Constructs logits processors based on the guided_decoding,
     logits_bias, and allowed_token_ids fields in sampling_params. Deletes
     those fields and adds the constructed logits processors to the
     logits_processors field. Modifies sampling params in-place and returns
     the modified sampling params."""
-    if (guided_decoding := sampling_params.guided_decoding) is None:
+    if sampling_params.guided_decoding is None:
         return sampling_params

+    # Defensively copy sampling params since guided decoding logits
+    # processors can have different state for each request
+    sampling_params = copy.copy(sampling_params)
+    guided_decoding = sampling_params.guided_decoding
+
     logger.debug("Building guided decoding logits processor. "
                  "Params: %s", guided_decoding)

     guided_decoding.backend = guided_decoding.backend or default_guided_backend

     processor = await get_guided_decoding_logits_processor(
-        guided_params=guided_decoding, tokenizer=tokenizer)
+        guided_params=guided_decoding,
+        tokenizer=tokenizer,
+        model_config=model_config)

     if processor:
         if sampling_params.logits_processors is None:

vllm/engine/llm_engine.py

@@ -1,3 +1,4 @@
+import copy
 import time
 from collections import Counter as collectionsCounter
 from collections import deque
@@ -1024,9 +1025,9 @@ class LLMEngine:
         This function updates num_computed_tokens for prompt sequences
         when Multi-Step is enabled.

         seq_group: SequenceGroup to update the num_computed_tokens for.
         seq_group_meta: Metadata of the given SequenceGroup.
         is_first_step_output: Optional[bool] -
             When available, is_first_step_output indicates if the appended
             output token is the output of the first-step in multi-step.
             A value of None indicates that outputs from all steps in
@@ -2036,7 +2037,11 @@ class LLMEngine:
         logits_processors = []

-        if (guided_decoding := sampling_params.guided_decoding) is not None:
+        if sampling_params.guided_decoding is not None:
+            # Defensively copy sampling params since guided decoding logits
+            # processors can have different state for each request
+            sampling_params = copy.copy(sampling_params)
+            guided_decoding = sampling_params.guided_decoding
             logger.debug(
                 "Building guided decoding logits processor in "
@@ -2047,7 +2052,9 @@ class LLMEngine:
                 self.decoding_config.guided_decoding_backend

             processor = get_local_guided_decoding_logits_processor(
-                guided_params=guided_decoding, tokenizer=tokenizer)
+                guided_params=guided_decoding,
+                tokenizer=tokenizer,
+                model_config=self.model_config)
             if processor:
                 logits_processors.append(processor)

vllm/engine/multiprocessing/client.py

@@ -474,8 +474,8 @@ class MQLLMEngineClient(EngineClient):
             trace_headers: OpenTelemetry trace headers.
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
             priority: Priority of the request (lower means earlier handling).
                 Any priority other than 0 will lead to an error if the
                 scheduling policy is not "priority".
         """
         if inputs is not None:
@@ -589,6 +589,7 @@ class MQLLMEngineClient(EngineClient):
                 default_guided_backend=(self.decoding_config.guided_decoding_backend
                     if self.decoding_config
                     else DecodingConfig.guided_decoding_backend),
+                model_config=self.model_config
             )

         # 1) Create output queue for this requests.

vllm/model_executor/guided_decoding/__init__.py

@@ -1,14 +1,54 @@
-from typing import Optional
+from __future__ import annotations

-from vllm.logits_process import LogitsProcessor
-from vllm.sampling_params import GuidedDecodingParams
+from typing import TYPE_CHECKING
+
+from vllm.logger import init_logger
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+    from vllm.config import ModelConfig
+    from vllm.logits_process import LogitsProcessor
+    from vllm.sampling_params import GuidedDecodingParams
+
+logger = init_logger(__name__)
+
+
+def maybe_backend_fallback(
+        guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
+    # lm-format-enforce doesn't support grammar, fallback to xgrammar
+    if (guided_params.backend == "lm-format-enforcer"
+            and guided_params.grammar is not None):
+        logger.warning(
+            "lm-format-enforcer does not support grammar guided decoding. "
+            "Falling back to use xgrammar instead.")
+        guided_params.backend = "xgrammar"
+
+    if guided_params.backend == "xgrammar":
+        # xgrammar doesn't support regex or choice, fallback to outlines
+        if guided_params.regex is not None or guided_params.choice is not None:
+            logger.warning(
+                "xgrammar only supports json or grammar guided decoding. "
+                "Falling back to use outlines instead.")
+            guided_params.backend = "outlines"
+
+        # xgrammar only supports EBNF grammars and uses the GBNF format
+        # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
+        elif (guided_params.grammar is not None
+              and "::=" not in guided_params.grammar):
+            logger.warning("xgrammar only supports EBNF grammars. "
+                           "Falling back to use outlines instead.")
+            guided_params.backend = "outlines"
+
+    return guided_params


 async def get_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams,
-        tokenizer) -> Optional[LogitsProcessor]:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
+    guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines' or guided_params.grammar:
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_outlines_guided_decoding_logits_processor)
@@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'xgrammar':
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer, model_config)

     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")


 def get_local_guided_decoding_logits_processor(
-        guided_params: GuidedDecodingParams,
-        tokenizer) -> Optional[LogitsProcessor]:
+        guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+        model_config: ModelConfig) -> LogitsProcessor | None:
+    guided_params = maybe_backend_fallback(guided_params)
     # CFG grammar not supported by LMFE, so we use outlines instead
-    if guided_params.backend == 'outlines' or guided_params.grammar:
+    if guided_params.backend == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
         from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
             get_local_outlines_guided_decoding_logits_processor)
@@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'xgrammar':
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer, model_config)

     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")

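For context, a sketch (not part of this diff) of the fallback rules encoded in maybe_backend_fallback above, assuming the patched module is importable.

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# xgrammar handles only JSON and EBNF grammars: regex (and choice) requests
# fall back to outlines.
p = GuidedDecodingParams(regex=r"\d+", backend="xgrammar")
assert maybe_backend_fallback(p).backend == "outlines"

# lm-format-enforcer has no grammar support: GBNF-style grammars ("::=")
# are rerouted to xgrammar.
p = GuidedDecodingParams(grammar='root ::= "yes" | "no"',
                         backend="lm-format-enforcer")
assert maybe_backend_fallback(p).backend == "xgrammar"

# Grammars that are not EBNF (no "::=", e.g. Lark syntax) end up on outlines.
p = GuidedDecodingParams(grammar="start: 'yes' | 'no'", backend="xgrammar")
assert maybe_backend_fallback(p).backend == "outlines"
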
vllm/model_executor/guided_decoding/xgrammar_decoding.py (new file)

@@ -0,0 +1,251 @@
# noqa: UP007
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, NamedTuple

import torch
from transformers import PreTrainedTokenizerFast

try:
    import xgrammar as xgr
    from xgrammar.base import _core as xgr_core
except ImportError:
    pass

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from vllm.config import ModelConfig
    from vllm.sampling_params import GuidedDecodingParams


# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        max_threads: int = 8):
    config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                              model_config=model_config,
                                              tokenizer=tokenizer,
                                              max_threads=max_threads)
    return XGrammarLogitsProcessor(config)


class TokenizerData(NamedTuple):
    """Immutable container for cached tokenizer data."""
    encoded_vocab: list[str]
    stop_token_ids: list[int] | None
    backend_str: str


class TokenizerDataCache:
    """Cache manager for tokenizer data to avoid repeated processing."""
    _cache: dict[int, TokenizerData] = {}

    @classmethod
    def get_tokenizer_data(cls,
                           tokenizer: PreTrainedTokenizer) -> TokenizerData:
        tokenizer_hash = hash(tokenizer)

        if tokenizer_hash not in cls._cache:
            # Vendored from xgrammar logic since we cannot pickle the tokenizer
            # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
            try:
                encoded_vocab = [
                    token for token, _ in sorted(tokenizer.get_vocab().items(),
                                                 key=lambda x: x[1])
                ]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e

            stop_token_ids = None
            backend_str = xgr.VocabType.RAW
            if isinstance(tokenizer, PreTrainedTokenizerFast):
                backend_str = tokenizer.backend_tokenizer.to_str()
                if stop_token_ids is None and hasattr(
                        tokenizer,
                        "eos_token_id") and tokenizer.eos_token_id is not None:
                    stop_token_ids = [tokenizer.eos_token_id]

            cls._cache[tokenizer_hash] = TokenizerData(
                encoded_vocab=encoded_vocab,
                stop_token_ids=stop_token_ids,
                backend_str=backend_str)

        return cls._cache[tokenizer_hash]


class GrammarCompilerCache:
    """
    Cache for GrammarCompiler instances based on tokenizer.

    This cache reduces the overhead of creating new compiler instances when
    using the same tokenizer configuration.
    """
    _cache: dict[str, xgr.GrammarCompiler] = {}

    @classmethod
    def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
        cache_key = str(config.tokenizer_hash)

        if cache_key not in cls._cache:
            assert config.encoded_vocab is not None
            tokenizer_info = xgr.TokenizerInfo._create_from_handle(
                xgr_core.TokenizerInfo.from_huggingface(
                    config.encoded_vocab, config.backend_str,
                    config.vocab_size, config.stop_token_ids))
            cls._cache[cache_key] = xgr.GrammarCompiler(
                tokenizer_info, max_threads=config.max_threads)

        return cls._cache[cache_key]


@dataclass
class GrammarConfig:
    """Serializable configuration for grammar compilation"""
    tokenizer_hash: int
    vocab_size: int
    json_str: str | None = None
    grammar_str: str | None = None
    json_object: bool | None = None
    max_threads: int = 8
    # Only populated if tokenizer_hash not in cache
    encoded_vocab: list[str] | None = None
    stop_token_ids: list[int] | None = None
    backend_str: str | None = None

    @classmethod
    def from_guided_params(cls,
                           guided_params: GuidedDecodingParams,
                           model_config: ModelConfig,
                           tokenizer: PreTrainedTokenizer,
                           max_threads: int = 8) -> GrammarConfig:

        tokenizer_hash = hash(tokenizer)
        # Only get tokenizer data if not already cached
        if tokenizer_hash in TokenizerDataCache._cache:
            encoded_vocab = None
            stop_token_ids = None
            backend_str = None
        else:
            tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
            encoded_vocab = tokenizer_data.encoded_vocab
            stop_token_ids = tokenizer_data.stop_token_ids
            backend_str = tokenizer_data.backend_str

        if guided_params.json:
            if not isinstance(guided_params.json, str):
                json_str = json.dumps(guided_params.json)
            else:
                json_str = guided_params.json
            return cls(json_str=json_str,
                       vocab_size=model_config.hf_config.vocab_size,
                       encoded_vocab=encoded_vocab,
                       stop_token_ids=stop_token_ids,
                       backend_str=backend_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads)
        elif guided_params.grammar:
            return cls(grammar_str=guided_params.grammar,
                       vocab_size=model_config.hf_config.vocab_size,
                       encoded_vocab=encoded_vocab,
                       stop_token_ids=stop_token_ids,
                       backend_str=backend_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads)
        elif guided_params.json_object:
            return cls(json_object=True,
                       vocab_size=model_config.hf_config.vocab_size,
                       encoded_vocab=encoded_vocab,
                       stop_token_ids=stop_token_ids,
                       backend_str=backend_str,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads)
        else:
            raise ValueError(
                "Currently only support JSON and EBNF grammar mode for xgrammar"
            )


@dataclass
class XGrammarLogitsProcessor:
    """Wrapper class to support pickle protocol"""
    config: GrammarConfig

    ctx: xgr.CompiledGrammar | None = None
    token_bitmask: torch.Tensor = None  # type: ignore[assignment]
    matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
    batch_size: int = field(default=1)
    prefilled: bool = field(default=False)

    def __getstate__(self) -> dict[str, Any]:
        return {'config': self.config}

    def __setstate__(self, state: dict[str, Any]):
        self.config = state['config']

        self.ctx = None
        self.matchers = []
        self.batch_size = 1
        self.token_bitmask = None  # type: ignore[assignment]
        self.prefilled = False

    def _ensure_ctx(self):
        """Lazily initialize the processor in the worker process"""
        if self.ctx is None:
            compiler = GrammarCompilerCache.get_compiler(self.config)
            if self.config.json_str is not None:
                self.ctx = compiler.compile_json_schema(self.config.json_str)
            elif self.config.grammar_str is not None:
                self.ctx = compiler.compile_grammar(self.config.grammar_str)
            elif self.config.json_object:
                self.ctx = compiler.compile_builtin_json_grammar()
            else:
                raise ValueError(
                    "Invalid configuration for xgrammar logits processor")

    def __call__(self, input_ids: list[int],
                 scores: torch.Tensor) -> torch.Tensor:
        if self.ctx is None:
            self._ensure_ctx()

        if len(self.matchers) == 0:
            self.matchers = [
                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
            ]
            self.token_bitmask = xgr.allocate_token_bitmask(
                self.batch_size, self.config.vocab_size)

        if not self.prefilled:
            # Have not sampled a token yet
            self.prefilled = True
        else:
            for i, matcher in enumerate(self.matchers):
                if not matcher.is_terminated():
                    sampled_token = input_ids[-1]
                    assert self.matchers[i].accept_token(sampled_token)

        for i, matcher in enumerate(self.matchers):
            if not matcher.is_terminated():
                # @ubospica: ideally, fill_next_token_bitmask should be
                # parallelized with model decoding
                # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
                matcher.fill_next_token_bitmask(self.token_bitmask, i)

        # token_bitmask is a CPU tensor for use with accept_token and
        # fill_next_token_bitmask so we move it to the device of scores
        device_type = scores.device.type
        if device_type != "cuda":
            scores = scores.to("cpu")
        xgr.apply_token_bitmask_inplace(scores,
                                        self.token_bitmask.to(scores.device))
        if device_type != "cuda":
            scores = scores.to(device_type)

        return scores
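
For context, a sketch (not part of this diff) of the pickle contract the wrapper above is built around: only the serializable GrammarConfig crosses the process boundary, while the compiled grammar, matchers and token bitmask are rebuilt lazily in the worker on first __call__. The field values below are dummies.

import pickle

from vllm.model_executor.guided_decoding.xgrammar_decoding import (
    GrammarConfig, XGrammarLogitsProcessor)

config = GrammarConfig(tokenizer_hash=0, vocab_size=32000, json_object=True)
proc = XGrammarLogitsProcessor(config)

clone = pickle.loads(pickle.dumps(proc))
assert clone.config == proc.config  # config survives the round trip
assert clone.ctx is None            # compiled grammar is re-created lazily
assert clone.matchers == []         # matchers/bitmask rebuilt on first call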