mirror of https://github.com/vllm-project/vllm
[ROCm][V1] Add intial ROCm support to V1 (#12790)
This commit is contained in:
parent
cbc40128eb
commit
ba59b78a9c
|
@ -0,0 +1,16 @@
|
|||
# Common dependencies
|
||||
-r requirements-common.txt
|
||||
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.2
|
||||
torch==2.5.1
|
||||
torchvision==0.20.1
|
||||
torchaudio==2.5.1
|
||||
|
||||
cmake>=3.26
|
||||
ninja
|
||||
packaging
|
||||
setuptools>=61
|
||||
setuptools-scm>=8
|
||||
wheel
|
||||
jinja2
|
||||
amdsmi==6.2.4
|
|
@ -718,7 +718,8 @@ if triton.__version__ >= "2.1.0":
|
|||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None):
|
||||
sliding_window=None,
|
||||
sm_scale=None):
|
||||
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
# need to reduce num. blocks when using fp32
|
||||
|
@ -759,7 +760,8 @@ if triton.__version__ >= "2.1.0":
|
|||
# round up Lk to a power of 2 - this is required for Triton block size
|
||||
Lk_padded = triton.next_power_of_2(Lk)
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
|
@ -29,12 +28,6 @@ try:
|
|||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._rocm_C with %r", e)
|
||||
|
||||
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
|
||||
logger.warning("`fork` method is not supported by ROCm. "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
|
||||
" `spawn` instead.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS: List[str] = []
|
||||
|
||||
|
@ -84,6 +77,9 @@ class RocmPlatform(Platform):
|
|||
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
||||
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||
== _Backend.FLASH_ATTN else selected_backend)
|
||||
if envs.VLLM_USE_V1:
|
||||
logger.info("Using ROCm Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
|
||||
if selected_backend == _Backend.ROCM_FLASH:
|
||||
if not cls.has_device_capability(90):
|
||||
# not Instinct series GPUs.
|
||||
|
@ -102,7 +98,11 @@ class RocmPlatform(Platform):
|
|||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
# NOTE: When using V1 this function is called when overriding the
|
||||
# engine args. Calling torch.cuda.get_device_name(device_id) here
|
||||
# will result in the ROCm context being initialized before other
|
||||
# processes can be created.
|
||||
return "AMD"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
@ -129,15 +129,30 @@ class RocmPlatform(Platform):
|
|||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on VLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not yet supported on VLLM V1."
|
||||
)
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
|
|
|
@ -12,8 +12,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Attention layer with PagedAttention on rocm"""
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ROCmAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_ATTN_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["ROCmAttentionImpl"]:
|
||||
return ROCmAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ROCmAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"ROCmAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by ROCmAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"ROCmAttentionImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
assert attn_metadata.use_cascade is False
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# TODO(sage): Refactor the context_attention_fwd kernel so that this
|
||||
# overhead can be removed
|
||||
context_lens = torch.empty_like(attn_metadata.seq_lens)
|
||||
batch_size = len(attn_metadata.query_start_loc) - 1
|
||||
assert len(context_lens) == batch_size
|
||||
for i in range(batch_size):
|
||||
query_start = attn_metadata.query_start_loc[i]
|
||||
query_end = attn_metadata.query_start_loc[i + 1]
|
||||
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
|
||||
query_start)
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
context_attention_fwd(q=query[:num_actual_tokens],
|
||||
k=key[:num_actual_tokens],
|
||||
v=value[:num_actual_tokens],
|
||||
o=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
b_loc=attn_metadata.block_table,
|
||||
b_start_loc=attn_metadata.query_start_loc,
|
||||
b_seq_len=attn_metadata.seq_lens,
|
||||
b_ctx_len=context_lens,
|
||||
max_input_len=attn_metadata.max_query_len,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
return output
|
Loading…
Reference in New Issue