[Core] Tweaks to model runner/input builder developer APIs (#6712)

Antoni Baum 2024-07-24 12:17:12 -07:00 committed by GitHub
parent 0e63494cf3
commit 5448f67635
3 changed files with 102 additions and 57 deletions


@@ -297,23 +297,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if is_profile_run:
             return
 
-        # Get the number of valid blocks based on sequence length.
-        # If seq_len = 16, block_size = 16,
-        # block_table_bound is 1 with 1 valid block.
-        # If seq_len = 15, block_size = 16,
-        # block_table_bound is 0 + 1 with 1 valid block.
-        block_table_bound = seq_len // self.block_size + 1 \
-            if seq_len % self.block_size != 0 \
-            else seq_len // self.block_size
-
         block_table = block_tables[seq_id]
-        self.paged_kv_indices.extend(block_table[:block_table_bound])
-        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
-                                    block_table_bound)
+        self._update_paged_kv_tensors(block_table, seq_len)
 
-        last_page_len = seq_len % self.block_size
-        if last_page_len == 0:
-            last_page_len = self.block_size
-        self.paged_kv_last_page_len.append(last_page_len)
+    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
+        # Get the number of valid blocks based on sequence length.
+        # If seq_len = 16, block_size = 16,
+        # block_table_bound is 1 with 1 valid block.
+        # If seq_len = 15, block_size = 16,
+        # block_table_bound is 0 + 1 with 1 valid block.
+        block_table_bound = seq_len // self.block_size + 1 \
+            if seq_len % self.block_size != 0 \
+            else seq_len // self.block_size
+        self.paged_kv_indices.extend(block_table[:block_table_bound])
+        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
+                                    block_table_bound)
 
+        last_page_len = seq_len % self.block_size
+        if last_page_len == 0:
+            last_page_len = self.block_size
+        self.paged_kv_last_page_len.append(last_page_len)
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
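
As a sanity check on the ceiling-division and last-page arithmetic that `_update_paged_kv_tensors` now owns, here is a minimal standalone sketch (plain Python, no vLLM imports; the function name `paged_kv_bounds` is made up for illustration):

```python
def paged_kv_bounds(seq_len: int, block_size: int) -> tuple:
    """Mirror the bookkeeping in the hunk above for concrete numbers."""
    # Number of blocks that hold at least one valid token
    # (ceiling division of seq_len by block_size).
    block_table_bound = (seq_len // block_size + 1
                         if seq_len % block_size != 0
                         else seq_len // block_size)
    # Number of valid tokens in the last (possibly partial) block.
    last_page_len = seq_len % block_size
    if last_page_len == 0:
        last_page_len = block_size
    return block_table_bound, last_page_len


assert paged_kv_bounds(seq_len=16, block_size=16) == (1, 16)
assert paged_kv_bounds(seq_len=15, block_size=16) == (1, 15)
assert paged_kv_bounds(seq_len=17, block_size=16) == (2, 1)
```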


@@ -11,7 +11,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
+from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
+                                      ModelInputForGPUBuilder)
 
 logger = init_logger(__name__)
 
@@ -28,6 +29,7 @@ class EmbeddingModelRunner(
         GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
         ModelInputForGPUWithPoolingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def __init__(
         self,
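
The new `_builder_cls` attribute is a small class-attribute factory: the shared runner code instantiates whatever builder class the concrete runner declares, so a subclass only swaps the attribute instead of overriding the driver loop. A minimal, vLLM-independent sketch of the pattern (all names here are illustrative):

```python
from typing import List, Type


class InputBuilder:
    """Collects items and produces a finalized list."""

    def __init__(self, runner: "RunnerBase") -> None:
        self.runner = runner
        self.items: List[str] = []

    def add(self, item: str) -> None:
        self.items.append(item)

    def build(self) -> List[str]:
        return list(self.items)


class RunnerBase:
    # Subclasses point this at their own builder instead of
    # re-implementing prepare().
    _builder_cls: Type[InputBuilder] = InputBuilder

    def prepare(self, items: List[str]) -> List[str]:
        builder = self._builder_cls(self)
        for item in items:
            builder.add(item)
        return builder.build()


class UpperCaseBuilder(InputBuilder):
    def add(self, item: str) -> None:
        super().add(item.upper())


class CustomRunner(RunnerBase):
    _builder_cls = UpperCaseBuilder


assert CustomRunner().prepare(["a", "b"]) == ["A", "B"]
```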


@@ -3,7 +3,7 @@ import gc
 import time
 import warnings
 import weakref
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                     Tuple, Type, TypeVar, Union)
@@ -171,48 +171,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
 class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     """Build ModelInputForGPU from SequenceGroupMetadata."""
 
-    @dataclass
+    # Note: ideally we would be using a dataclass(kw_only=True)
+    # here, so that this can be subclassed easily,
+    # but kw_only is not supported in python<3.10.
     class InterDataForSeqGroup:
         """Intermediate data for the current sequence group."""
-        # From sequence group metadata.
-        request_id: str
-        seq_ids: List[int]
-        is_prompt: bool
-        block_tables: Optional[Dict[int, List[int]]]
-        computed_block_nums: List[int]
-        n_seqs: int = 0
-
-        # Input tokens and positions.
-        input_tokens: List[List[int]] = field(default_factory=list)
-        input_positions: List[List[int]] = field(default_factory=list)
-
-        # The sequence length (may be capped to the sliding window).
-        seq_lens: List[int] = field(default_factory=list)
-        # The original sequence length (before applying sliding window).
-        # This is used to compute slot mapping.
-        orig_seq_lens: List[int] = field(default_factory=list)
-        # The query length.
-        query_lens: List[int] = field(default_factory=list)
-        # The number of tokens that are already computed.
-        context_lens: List[int] = field(default_factory=list)
-        # The current sliding window block.
-        curr_sliding_window_blocks: List[int] = field(default_factory=list)
-
-        # LoRA inputs.
-        lora_index_mapping: List[List[int]] = field(default_factory=list)
-        lora_prompt_mapping: List[List[int]] = field(default_factory=list)
-        lora_requests: Set[LoRARequest] = field(default_factory=set)
-
-        # Prompt adapter inputs.
-        prompt_adapter_index_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
-
-        # Multi-modal inputs.
-        multi_modal_inputs: Optional[MultiModalInputs] = None
-
-        # Whether the prefix cache is hit (prefill only).
-        prefix_cache_hit: bool = False
+
+        def __init__(
+            self,
+            *,
+            # From sequence group metadata.
+            request_id: str,
+            seq_ids: List[int],
+            is_prompt: bool,
+            block_tables: Optional[Dict[int, List[int]]],
+            computed_block_nums: List[int],
+            n_seqs: int = 0,
+
+            # Input tokens and positions.
+            input_tokens: Optional[List[List[int]]] = None,
+            input_positions: Optional[List[List[int]]] = None,
+
+            # The sequence length (may be capped to the sliding window).
+            seq_lens: Optional[List[int]] = None,
+            # The original sequence length (before applying sliding window).
+            # This is used to compute slot mapping.
+            orig_seq_lens: Optional[List[int]] = None,
+            # The query length.
+            query_lens: Optional[List[int]] = None,
+            # The number of tokens that are already computed.
+            context_lens: Optional[List[int]] = None,
+            # The current sliding window block.
+            curr_sliding_window_blocks: Optional[List[int]] = None,
+
+            # LoRA inputs.
+            lora_index_mapping: Optional[List[List[int]]] = None,
+            lora_prompt_mapping: Optional[List[List[int]]] = None,
+            lora_requests: Optional[Set[LoRARequest]] = None,
+
+            # Prompt adapter inputs.
+            prompt_adapter_index_mapping: Optional[List[int]] = None,
+            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+
+            # Multi-modal inputs.
+            multi_modal_inputs: Optional[MultiModalInputs] = None,
+
+            # Whether the prefix cache is hit (prefill only).
+            prefix_cache_hit: bool = False,
+        ):
+            self.request_id = request_id
+            self.seq_ids = seq_ids
+            self.is_prompt = is_prompt
+            self.block_tables = block_tables
+            self.computed_block_nums = computed_block_nums
+            self.n_seqs = n_seqs
+
+            self.input_tokens = input_tokens or []
+            self.input_positions = input_positions or []
+            self.seq_lens = seq_lens or []
+            self.orig_seq_lens = orig_seq_lens or []
+            self.query_lens = query_lens or []
+            self.context_lens = context_lens or []
+            self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
+
+            self.lora_index_mapping = lora_index_mapping or []
+            self.lora_prompt_mapping = lora_prompt_mapping or []
+            self.lora_requests = lora_requests or set()
+
+            self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
+                                                 or [])
+            self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
+                                                  or [])
+            self.prompt_adapter_request = prompt_adapter_request
+
+            self.multi_modal_inputs = multi_modal_inputs
+            self.prefix_cache_hit = prefix_cache_hit
+
+            self.__post_init__()
 
         def __post_init__(self):
             self.n_seqs = len(self.seq_ids)
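
The note above about `dataclass(kw_only=True)` is about subclassing: keyword-only fields lift the ordering constraint that normally prevents a subclass from adding a required field after a parent field with a default. A short sketch of what the ideal Python >= 3.10 version would look like (illustrative classes, not the vLLM ones):

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class BaseInterData:
    request_id: str
    n_seqs: int = 0


@dataclass(kw_only=True)
class ExtendedInterData(BaseInterData):
    # With kw_only=True a subclass may add a required field even though the
    # parent already declares a field with a default; without kw_only this
    # raises "non-default argument follows default argument" at class
    # definition time.
    extra_state: str


data = ExtendedInterData(request_id="req-0", extra_state="hi")
assert data.n_seqs == 0
```

The keyword-only `__init__` above gives subclasses the same freedom on older Python versions, at the cost of spelling out the assignments by hand.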
@@ -457,6 +492,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         for per_seq_group_fn in self.per_seq_group_compute_fns:
             per_seq_group_fn(inter_data, seq_group_metadata)
 
+    def _use_captured_graph(self, batch_size: int,
+                            max_decode_seq_len: int) -> bool:
+        return (self.decode_only and not self.runner.model_config.enforce_eager
+                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.

@@ -491,10 +532,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         }
 
         batch_size = len(input_tokens)
-        use_captured_graph = (
-            self.decode_only and not self.runner.model_config.enforce_eager
-            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
-            and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+        use_captured_graph = self._use_captured_graph(batch_size,
+                                                      max_decode_seq_len)
 
         # If cuda graph can be used, pad tensors accordingly.
         # See `capture_model` API for more details.
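
Pulling the CUDA-graph check out into `_use_captured_graph` turns it into an overridable hook. For example, a builder subclass could disable graph capture without touching the rest of `build()`; the subclass below is hypothetical and assumes vLLM at this commit is importable:

```python
from vllm.worker.model_runner import ModelInputForGPUBuilder


class EagerOnlyInputBuilder(ModelInputForGPUBuilder):
    """Made-up builder that never pads inputs for CUDA graph replay."""

    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        # Ignore batch size and sequence length; always run eagerly.
        return False
```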
@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     Helper class for shared methods between GPU model runners.
     """
     _model_input_cls: Type[TModelInputForGPU]
+    _builder_cls: Type[ModelInputForGPUBuilder]
 
     def __init__(
         self,

@@ -794,8 +834,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = ModelInputForGPUBuilder(weakref.proxy(self),
-                                          finished_requests_ids)
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
         return builder.build()  # type: ignore
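
Together, `_builder_cls` on the runner and the builder hooks above form the developer API this commit is tweaking: a custom builder can be slotted into an existing runner without copying the shared input-preparation loop. A hypothetical example (the `Logging*` class names are invented; only `ModelRunner` and `ModelInputForGPUBuilder` are real vLLM symbols):

```python
from vllm.worker.model_runner import ModelInputForGPUBuilder, ModelRunner


class LoggingInputBuilder(ModelInputForGPUBuilder):
    """Made-up builder that inspects every sequence group it receives."""

    def add_seq_group(self, seq_group_metadata):
        print(f"adding request {seq_group_metadata.request_id}")
        super().add_seq_group(seq_group_metadata)


class LoggingModelRunner(ModelRunner):
    # The base class instantiates this via self._builder_cls(...), as shown
    # in the hunk above, so no other method needs to be overridden.
    _builder_cls = LoggingInputBuilder
```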
@@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
     """
     _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
         ModelInputForGPUWithSamplingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,