mirror of https://github.com/vllm-project/vllm
[Hardware][Intel GPU] refactor xpu_model_runner for tp (#7712)
This commit is contained in:
parent
c01a6cb231
commit
fc5ebbd1d3
|
@ -1,386 +1,37 @@
|
|||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
|
||||
Tuple, Union)
|
||||
from typing import List, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, PromptAdapterConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
|
||||
from vllm.executor.xpu_executor import XPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
from vllm.utils import get_vllm_instance_id, make_async
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
|
||||
class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
|
||||
|
||||
class RayXPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
assert (not speculative_config
|
||||
), "Speculative decoding not yet supported for XPU backend"
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.load_config = load_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
self.forward_dag = None
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||
# Updated by implementations that require additional args to be passed
|
||||
# to the _run_workers execute_model call
|
||||
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
pass
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks.
|
||||
|
||||
This invokes `determine_num_available_blocks` on each worker and takes
|
||||
the min of the results, guaranteeing that the selected cache sizes are
|
||||
compatible with all workers.
|
||||
|
||||
Returns:
|
||||
- Tuple[num_gpu_blocks, num_cpu_blocks]
|
||||
"""
|
||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||
|
||||
# Since we use a shared centralized controller, we take the minimum
|
||||
# number of blocks across all workers to make sure all the memory
|
||||
# operators can be applied to all workers.
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
|
||||
return dict(
|
||||
worker_module_name="vllm.worker.xpu_worker",
|
||||
worker_class_name="XPUWorker",
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
num_gpus = 1
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
if self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"GPU node.")
|
||||
|
||||
def _get_env_vars_to_be_updated(self):
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
|
||||
use_dummy_driver=True)
|
||||
|
||||
node_workers = defaultdict(list)
|
||||
node_gpus = defaultdict(list)
|
||||
VLLM_INSTANCE_ID = get_vllm_instance_id()
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
# TODO: add env var for xpu
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
def collect_arg_helper_func(**kwargs):
|
||||
# avoid writing `{"name": value}` manually
|
||||
return kwargs
|
||||
|
||||
init_worker_all_kwargs = []
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
|
||||
local_rank = node_workers[node_id].index(rank)
|
||||
init_worker_all_kwargs.append(
|
||||
collect_arg_helper_func(
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
load_config=self.load_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
is_driver_worker=rank == 0,
|
||||
))
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers(
|
||||
"load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers,
|
||||
)
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache in all workers.
|
||||
"""
|
||||
|
||||
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||
# greater than one. We could log in the engine, but not all executors
|
||||
# have GPUs.
|
||||
logger.info("# GPU blocks: %d, "
|
||||
"# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self._run_workers("initialize_cache",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks)
|
||||
|
||||
def _driver_execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
return self.driver_worker.execute_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||
return self._run_workers(
|
||||
"add_lora",
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return self._run_workers(
|
||||
"remove_lora",
|
||||
lora_id=lora_id,
|
||||
)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self._run_workers("list_loras")
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
use_dummy_driver: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
- args/kwargs and driver_args/driver_kwargs: Driver worker has
|
||||
different args
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
"""
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
count = len(self.workers)
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, 1, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
else islice(all_kwargs, 1, None)
|
||||
|
||||
# Start the ray workers first.
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_worker_output = []
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
# Start the driver worker after all the ray workers.
|
||||
if not use_dummy_driver:
|
||||
driver_worker_output = self.driver_worker.execute_method(
|
||||
method, *driver_args, **driver_kwargs)
|
||||
else:
|
||||
assert self.driver_dummy_worker is not None
|
||||
driver_worker_output = ray.get(
|
||||
self.driver_dummy_worker.execute_method.remote(
|
||||
method, *driver_args, **driver_kwargs))
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return driver_worker_output + ray_worker_outputs
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||
import pkg_resources
|
||||
from packaging import version
|
||||
|
||||
required_version = version.parse("2.32")
|
||||
current_version = version.parse(
|
||||
pkg_resources.get_distribution("ray").version)
|
||||
if current_version < required_version:
|
||||
raise ValueError(f"Ray version {required_version} or greater is "
|
||||
f"required, but found {current_version}")
|
||||
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
assert self.parallel_config.use_ray
|
||||
|
||||
# Right now, compiled DAG requires at least 1 arg. We send
|
||||
# a dummy value for now. It will be fixed soon.
|
||||
with InputNode() as input_data:
|
||||
forward_dag = MultiOutputNode([
|
||||
worker.execute_model_compiled_dag_remote.
|
||||
bind( # type: ignore[attr-defined]
|
||||
input_data) for worker in self.workers
|
||||
])
|
||||
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
self._check_if_any_actor_is_dead()
|
||||
|
||||
def _check_if_any_actor_is_dead(self):
|
||||
if not self.workers:
|
||||
return
|
||||
|
||||
dead_actors = []
|
||||
for actor in self.workers:
|
||||
actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access
|
||||
if actor_state["State"] == "DEAD":
|
||||
dead_actors.append(actor)
|
||||
if dead_actors:
|
||||
raise RuntimeError("At least one Worker is dead. "
|
||||
f"Dead Workers: {dead_actors}. ")
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"VLLM_INSTANCE_ID":
|
||||
VLLM_INSTANCE_ID,
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
}, ) for (_, _) in worker_node_and_gpu_ids]
|
||||
return all_args_to_update_environment_variables
|
||||
|
||||
|
||||
class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):
|
||||
class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import dataclasses
|
||||
import time
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
|
||||
TypeVar)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
|
@ -20,7 +23,7 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
|||
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase,
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
|
@ -37,6 +40,8 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
|||
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
|
||||
]
|
||||
|
||||
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForXPU(ModelRunnerInputBase):
|
||||
|
@ -46,11 +51,40 @@ class ModelInputForXPU(ModelRunnerInputBase):
|
|||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||
virtual_engine: Optional[int] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
query_lens: Optional[List[int]] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type[TModelInputForXPU],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> TModelInputForXPU:
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
|
||||
"""
|
||||
Used by the ModelRunner.
|
||||
"""
|
||||
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
|
@ -62,10 +96,10 @@ class ModelInputForXPU(ModelRunnerInputBase):
|
|||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type["ModelInputForXPU"],
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForXPU":
|
||||
) -> "ModelInputForXPUWithSamplingMetadata":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
|
@ -73,7 +107,230 @@ class ModelInputForXPU(ModelRunnerInputBase):
|
|||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
||||
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
|
||||
def __init__(self,
|
||||
runner: "XPUModelRunner",
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
self.runner = runner
|
||||
self.model_input_cls = self.runner._model_input_cls
|
||||
self.attn_backend = self.runner.attn_backend
|
||||
self.sliding_window = self.runner.sliding_window
|
||||
self.block_size = self.runner.block_size
|
||||
self.device = self.runner.device
|
||||
|
||||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
|
||||
self.seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
def build(self) -> ModelInputForXPU:
|
||||
is_prompt = self.seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs) = self._prepare_prompt(
|
||||
self.seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(
|
||||
self.seq_group_metadata_list)
|
||||
seq_lens = []
|
||||
multi_modal_kwargs = None
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=seq_lens,
|
||||
)
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
|
||||
BatchedTensorInputs]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_modal_inputs_list: List[MultiModalInputs] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = len(prompt_tokens)
|
||||
|
||||
seq_lens.append(seq_len) # Prompt token num
|
||||
input_tokens.extend(prompt_tokens) # Token ids
|
||||
|
||||
# Token position ids
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, seq_len)))
|
||||
|
||||
if seq_group_metadata.block_tables is None:
|
||||
# During memory profiling, the block tables are not initialized
|
||||
# yet. In this case, we just use a dummy slot mapping.
|
||||
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
||||
continue
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, seq_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
# block size is 4, the first two tokens are masked and the slot
|
||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
start_idx = 0
|
||||
if self.sliding_window is not None:
|
||||
start_idx = max(0, seq_len - self.sliding_window)
|
||||
|
||||
for i in range(computed_len, seq_len):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
|
||||
block_number = block_table[i //
|
||||
self.block_size] # type: ignore
|
||||
block_offset = i % self.block_size # type: ignore
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
num_prompt_tokens = len(input_tokens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
|
||||
max_seqlen = max(seq_lens)
|
||||
tmp = [0]
|
||||
tmp.extend(seq_lens)
|
||||
seqlen = torch.tensor(tmp)
|
||||
seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seqlen_q=seqlen_q,
|
||||
max_seqlen=max_seqlen,
|
||||
seq_lens_tensor=torch.tensor([]),
|
||||
max_decode_seq_len=0,
|
||||
num_prefills=len(seq_lens),
|
||||
num_prefill_tokens=num_prompt_tokens,
|
||||
num_decode_tokens=0,
|
||||
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
|
||||
)
|
||||
|
||||
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append(generation_token)
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
seq_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
if self.sliding_window is not None:
|
||||
sliding_window_blocks = (self.sliding_window //
|
||||
self.block_size)
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
max_decode_seq_len = max(seq_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seqlen_q=torch.tensor([]),
|
||||
max_seqlen=0,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=len(input_tokens),
|
||||
num_prefills=0,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
return (
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
_model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
|
||||
ModelInputForXPUWithSamplingMetadata)
|
||||
_builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -84,30 +341,32 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
return_hidden_states: bool = False,
|
||||
observability_config: Optional[ObservabilityConfig] = None,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.cache_config = cache_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.observability_config = observability_config
|
||||
if self.observability_config is not None:
|
||||
print(f"observability_config is {self.observability_config}")
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.device_config = device_config
|
||||
self.device = self.device_config.device
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
|
@ -203,166 +462,68 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
model_input = self.prepare_model_input(seqs)
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
self.execute_model(model_input, kv_caches)
|
||||
torch.xpu.synchronize()
|
||||
return
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
|
||||
return (ModelInputForXPU.from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str,
|
||||
Any]) -> ModelInputForXPUWithSamplingMetadata:
|
||||
return (
|
||||
ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
|
||||
def _prepare_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForXPUWithSamplingMetadata:
|
||||
"""Helper method to prepare the model input based on a given sequence
|
||||
group. Prepares metadata needed for the base model forward pass but not
|
||||
metadata for possible additional steps, e.g., sampling.
|
||||
|
||||
"""
|
||||
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
builder.add_seq_group(seq_group_metadata)
|
||||
|
||||
return builder.build() # type: ignore
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForXPU:
|
||||
multi_modal_kwargs = None
|
||||
if self.is_driver_worker:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs
|
||||
) = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
|
||||
seq_lens = []
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
# subquery_lens is not needed if chunked prefill is not
|
||||
# supported. Since CPU worker doesn't support chunked prefill
|
||||
# just use seq_lens instead.
|
||||
seq_lens,
|
||||
) -> ModelInputForXPUWithSamplingMetadata:
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
pin_memory=False,
|
||||
generators=self.get_generators(finished_requests_ids))
|
||||
# Broadcast the metadata.
|
||||
metadata_dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"input_positions": input_positions,
|
||||
"selected_token_indices":
|
||||
sampling_metadata.selected_token_indices,
|
||||
"multi_modal_kwargs": multi_modal_kwargs,
|
||||
}
|
||||
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
else:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
input_tokens = metadata_dict.pop("input_tokens")
|
||||
input_positions = metadata_dict.pop("input_positions")
|
||||
selected_token_indices = metadata_dict.pop(
|
||||
"selected_token_indices")
|
||||
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
|
||||
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
num_prompts=0,
|
||||
)
|
||||
generators=generators)
|
||||
|
||||
return ModelInputForXPU(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs)
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append(generation_token)
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
seq_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
if self.sliding_window is not None:
|
||||
sliding_window_blocks = (self.sliding_window //
|
||||
self.block_size)
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
max_decode_seq_len = max(seq_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seqlen_q=None,
|
||||
max_seqlen=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=len(input_tokens),
|
||||
num_prefills=0,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
return (
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
)
|
||||
virtual_engine=virtual_engine)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForXPU,
|
||||
model_input: ModelInputForXPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
|
@ -372,20 +533,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||
"XPUModelRunner does not support multi-step execution.")
|
||||
|
||||
model_executable = self.model
|
||||
execute_model_kwargs = {
|
||||
"input_ids":
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
"attn_metadata":
|
||||
model_input.attn_metadata,
|
||||
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
}
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start_time = time.time()
|
||||
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device))
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end_time = time.time()
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
|
@ -396,109 +558,19 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
|||
return []
|
||||
|
||||
# Sample the next token.
|
||||
output = self.model.sample(
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_time = (model_forward_end_time -
|
||||
model_forward_start_time)
|
||||
# If there are multiple workers, we are still tracking the latency
|
||||
# from the start time of the driver worker to the end time of the
|
||||
# driver worker. The model forward time will then end up covering
|
||||
# the communication time as well.
|
||||
output.model_forward_time = model_forward_time
|
||||
|
||||
return [output]
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
|
||||
BatchedTensorInputs]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_modal_inputs_list: List[MultiModalInputs] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = len(prompt_tokens)
|
||||
|
||||
seq_lens.append(seq_len) # Prompt token num
|
||||
input_tokens.extend(prompt_tokens) # Token ids
|
||||
|
||||
# Token position ids
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, seq_len)))
|
||||
|
||||
mm_data = seq_group_metadata.multi_modal_data
|
||||
if mm_data:
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
multi_modal_inputs_list.append(mm_kwargs)
|
||||
|
||||
if seq_group_metadata.block_tables is None:
|
||||
# During memory profiling, the block tables are not initialized
|
||||
# yet. In this case, we just use a dummy slot mapping.
|
||||
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
||||
continue
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, seq_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
# block size is 4, the first two tokens are masked and the slot
|
||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
start_idx = 0
|
||||
if self.sliding_window is not None:
|
||||
start_idx = max(0, seq_len - self.sliding_window)
|
||||
|
||||
for i in range(computed_len, seq_len):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
|
||||
block_number = block_table[i //
|
||||
self.block_size] # type: ignore
|
||||
block_offset = i % self.block_size # type: ignore
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
num_prompt_tokens = len(input_tokens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device) # type: ignore
|
||||
|
||||
max_seqlen = max(seq_lens)
|
||||
tmp = [0]
|
||||
tmp.extend(seq_lens)
|
||||
seqlen = torch.tensor(tmp)
|
||||
seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seqlen_q=seqlen_q,
|
||||
max_seqlen=max_seqlen,
|
||||
seq_lens_tensor=None,
|
||||
max_decode_seq_len=None,
|
||||
num_prefills=len(seq_lens),
|
||||
num_prefill_tokens=num_prompt_tokens,
|
||||
num_decode_tokens=0,
|
||||
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
|
||||
)
|
||||
|
||||
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_kwargs)
|
||||
|
|
|
@ -9,8 +9,8 @@ import torch
|
|||
import torch.distributed
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ObservabilityConfig,
|
||||
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
|
@ -46,7 +46,6 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
|
|||
rank: int,
|
||||
distributed_init_method: str,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
|
@ -73,8 +72,6 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
|
|||
assert rank % parallel_config.tensor_parallel_size == 0, \
|
||||
"Driver worker should be rank 0 of tensor parallel group."
|
||||
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.model_runner = XPUModelRunner( # type: ignore
|
||||
model_config,
|
||||
parallel_config,
|
||||
|
@ -85,7 +82,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
|
|||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
multimodal_config=multimodal_config,
|
||||
observability_config=self.observability_config,
|
||||
)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
|
|
Loading…
Reference in New Issue