mirror of https://github.com/vllm-project/vllm

[Core] Tweaks to model runner/input builder developer APIs (#6712)

parent 0e63494cf3
commit 5448f67635
vllm/attention/backends/flashinfer.py

@@ -297,23 +297,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if is_profile_run:
             return
 
-        # Get the number of valid blocks based on sequence length.
-        # If seq_len = 16, block_size = 16,
-        # block_table_bound is 1 with 1 valid block.
-        # If seq_len = 15, block_size = 16,
-        # block_table_bound is 0 + 1 with 1 valid block.
-        block_table_bound = seq_len // self.block_size + 1 \
-                            if seq_len % self.block_size != 0 \
-                            else seq_len // self.block_size
         block_table = block_tables[seq_id]
+        self._update_paged_kv_tensors(block_table, seq_len)
+
+    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
+        # Get the number of valid blocks based on sequence length.
+        # If seq_len = 16, block_size = 16,
+        # block_table_bound is 1 with 1 valid block.
+        # If seq_len = 15, block_size = 16,
+        # block_table_bound is 0 + 1 with 1 valid block.
+        block_table_bound = seq_len // self.block_size + 1 \
+                            if seq_len % self.block_size != 0 \
+                            else seq_len // self.block_size
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
 
         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
             last_page_len = self.block_size
         self.paged_kv_last_page_len.append(last_page_len)
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
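The arithmetic the new _update_paged_kv_tensors helper centralizes is a ceiling division (how many blocks the sequence occupies) plus a last-block remainder (how many tokens sit in the final block). A minimal standalone sketch of that math; the helper name and sample values below are illustrative, not part of the diff:

def paged_kv_bounds(seq_len: int, block_size: int):
    # Ceiling division: number of blocks needed to hold seq_len tokens.
    block_table_bound = seq_len // block_size + 1 \
        if seq_len % block_size != 0 \
        else seq_len // block_size
    # Tokens occupying the final block; a full block counts as block_size.
    last_page_len = seq_len % block_size
    if last_page_len == 0:
        last_page_len = block_size
    return block_table_bound, last_page_len

assert paged_kv_bounds(16, 16) == (1, 16)  # exactly one full block
assert paged_kv_bounds(15, 16) == (1, 15)  # one partially filled block
assert paged_kv_bounds(17, 16) == (2, 1)   # one token spills into a second block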
vllm/worker/embedding_model_runner.py

@@ -11,7 +11,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
+from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
+                                      ModelInputForGPUBuilder)
 
 logger = init_logger(__name__)
 
@@ -28,6 +29,7 @@ class EmbeddingModelRunner(
         GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
         ModelInputForGPUWithPoolingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def __init__(
         self,
vllm/worker/model_runner.py

@@ -3,7 +3,7 @@ import gc
 import time
 import warnings
 import weakref
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                     Tuple, Type, TypeVar, Union)
 
@@ -171,48 +171,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
 class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     """Build ModelInputForGPU from SequenceGroupMetadata."""
 
-    @dataclass
+    # Note: ideally we would be using a dataclass(kw_only=True)
+    # here, so that this can be subclassed easily,
+    # but kw_only is not supported in python<3.10.
     class InterDataForSeqGroup:
         """Intermediate data for the current sequence group."""
-        # From sequence group metadata.
-        request_id: str
-        seq_ids: List[int]
-        is_prompt: bool
-        block_tables: Optional[Dict[int, List[int]]]
-        computed_block_nums: List[int]
-        n_seqs: int = 0
-
-        # Input tokens and positions.
-        input_tokens: List[List[int]] = field(default_factory=list)
-        input_positions: List[List[int]] = field(default_factory=list)
-
-        # The sequence length (may be capped to the sliding window).
-        seq_lens: List[int] = field(default_factory=list)
-        # The original sequence length (before applying sliding window).
-        # This is used to compute slot mapping.
-        orig_seq_lens: List[int] = field(default_factory=list)
-        # The query length.
-        query_lens: List[int] = field(default_factory=list)
-        # The number of tokens that are already computed.
-        context_lens: List[int] = field(default_factory=list)
-        # The current sliding window block.
-        curr_sliding_window_blocks: List[int] = field(default_factory=list)
-
-        # LoRA inputs.
-        lora_index_mapping: List[List[int]] = field(default_factory=list)
-        lora_prompt_mapping: List[List[int]] = field(default_factory=list)
-        lora_requests: Set[LoRARequest] = field(default_factory=set)
-
-        # Prompt adapter inputs.
-        prompt_adapter_index_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
-
-        # Multi-modal inputs.
-        multi_modal_inputs: Optional[MultiModalInputs] = None
-
-        # Whether the prefix cache is hit (prefill only).
-        prefix_cache_hit: bool = False
+
+        def __init__(
+            self,
+            *,
+            # From sequence group metadata.
+            request_id: str,
+            seq_ids: List[int],
+            is_prompt: bool,
+            block_tables: Optional[Dict[int, List[int]]],
+            computed_block_nums: List[int],
+            n_seqs: int = 0,
+
+            # Input tokens and positions.
+            input_tokens: Optional[List[List[int]]] = None,
+            input_positions: Optional[List[List[int]]] = None,
+
+            # The sequence length (may be capped to the sliding window).
+            seq_lens: Optional[List[int]] = None,
+            # The original sequence length (before applying sliding window).
+            # This is used to compute slot mapping.
+            orig_seq_lens: Optional[List[int]] = None,
+            # The query length.
+            query_lens: Optional[List[int]] = None,
+            # The number of tokens that are already computed.
+            context_lens: Optional[List[int]] = None,
+            # The current sliding window block.
+            curr_sliding_window_blocks: Optional[List[int]] = None,
+
+            # LoRA inputs.
+            lora_index_mapping: Optional[List[List[int]]] = None,
+            lora_prompt_mapping: Optional[List[List[int]]] = None,
+            lora_requests: Optional[Set[LoRARequest]] = None,
+
+            # Prompt adapter inputs.
+            prompt_adapter_index_mapping: Optional[List[int]] = None,
+            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+
+            # Multi-modal inputs.
+            multi_modal_inputs: Optional[MultiModalInputs] = None,
+
+            # Whether the prefix cache is hit (prefill only).
+            prefix_cache_hit: bool = False,
+        ):
+            self.request_id = request_id
+            self.seq_ids = seq_ids
+            self.is_prompt = is_prompt
+            self.block_tables = block_tables
+            self.computed_block_nums = computed_block_nums
+            self.n_seqs = n_seqs
+
+            self.input_tokens = input_tokens or []
+            self.input_positions = input_positions or []
+            self.seq_lens = seq_lens or []
+            self.orig_seq_lens = orig_seq_lens or []
+            self.query_lens = query_lens or []
+            self.context_lens = context_lens or []
+            self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
+
+            self.lora_index_mapping = lora_index_mapping or []
+            self.lora_prompt_mapping = lora_prompt_mapping or []
+            self.lora_requests = lora_requests or set()
+
+            self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
+                                                 or [])
+            self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
+                                                  or [])
+            self.prompt_adapter_request = prompt_adapter_request
+
+            self.multi_modal_inputs = multi_modal_inputs
+            self.prefix_cache_hit = prefix_cache_hit
+
+            self.__post_init__()
 
         def __post_init__(self):
             self.n_seqs = len(self.seq_ids)
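The comment replacing @dataclass states the motivation: dataclass(kw_only=True) would let subclasses add required fields, but it only exists on Python >= 3.10. A minimal sketch of the problem and of the workaround the diff adopts; Base, KwOnlyBase, and Sub are hypothetical names for illustration, not vLLM classes:

from dataclasses import dataclass

@dataclass
class Base:
    request_id: str
    n_seqs: int = 0

# Subclassing the plain dataclass with a new *required* field fails at class
# definition time on Python < 3.10:
#
#     @dataclass
#     class Sub(Base):
#         extra: list  # TypeError: non-default argument 'extra'
#                      # follows default argument 'n_seqs'

# An explicit keyword-only __init__, as in InterDataForSeqGroup above,
# sidesteps the argument-ordering rule entirely:
class KwOnlyBase:
    def __init__(self, *, request_id: str, n_seqs: int = 0):
        self.request_id = request_id
        self.n_seqs = n_seqs

class Sub(KwOnlyBase):
    def __init__(self, *, extra: list, **kwargs):
        super().__init__(**kwargs)
        self.extra = extra

s = Sub(request_id="req-0", extra=[1, 2, 3])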
@@ -457,6 +492,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         for per_seq_group_fn in self.per_seq_group_compute_fns:
             per_seq_group_fn(inter_data, seq_group_metadata)
 
+    def _use_captured_graph(self, batch_size: int,
+                            max_decode_seq_len: int) -> bool:
+        return (self.decode_only and not self.runner.model_config.enforce_eager
+                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
 
@@ -491,10 +532,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         }
 
         batch_size = len(input_tokens)
-        use_captured_graph = (
-            self.decode_only and not self.runner.model_config.enforce_eager
-            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
-            and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+        use_captured_graph = self._use_captured_graph(batch_size,
+                                                      max_decode_seq_len)
 
         # If cuda graph can be used, pad tensors accordingly.
         # See `capture_model` API for more details.
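Factoring the CUDA-graph eligibility check into _use_captured_graph gives builder subclasses one small predicate to override instead of reimplementing build(). A sketch of how that hook could be used; MyBuilder and its batch-size policy are hypothetical, chosen only to show the override point:

class MyBuilder(ModelInputForGPUBuilder):
    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        # Example policy: skip captured graphs for very small batches,
        # otherwise defer to the default eligibility rules.
        if batch_size < 4:
            return False
        return super()._use_captured_graph(batch_size, max_decode_seq_len)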
@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     Helper class for shared methods between GPU model runners.
     """
     _model_input_cls: Type[TModelInputForGPU]
+    _builder_cls: Type[ModelInputForGPUBuilder]
 
     def __init__(
         self,
@@ -794,8 +834,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = ModelInputForGPUBuilder(weakref.proxy(self),
-                                          finished_requests_ids)
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
         return builder.build()  # type: ignore
@@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
     """
     _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
         ModelInputForGPUWithSamplingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,
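Taken together, these changes make the builder pluggable: the input-preparation path now instantiates self._builder_cls rather than a hard-coded ModelInputForGPUBuilder, so a runner subclass can swap in its own builder with a single class attribute, exactly as EmbeddingModelRunner and ModelRunner do above. A sketch of a downstream runner using the new hook; CustomBuilder and CustomRunner are hypothetical names:

class CustomBuilder(ModelInputForGPUBuilder):
    def add_seq_group(self, seq_group_metadata):
        # Custom per-sequence-group preprocessing could go here.
        super().add_seq_group(seq_group_metadata)

class CustomRunner(ModelRunner):
    _builder_cls: Type[ModelInputForGPUBuilder] = CustomBuilder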