mirror of https://github.com/vllm-project/vllm
[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA, add partially supported model notes (#6543)
This commit is contained in:
parent 683e3cb9c4
commit 06d6c5fe9f
@@ -66,11 +66,18 @@ trap remove_docker_container EXIT

echo "--- Running container"

+HF_CACHE="$(realpath ~)/huggingface"
+mkdir -p ${HF_CACHE}
+HF_MOUNT="/root/.cache/huggingface"

docker run \
        --device /dev/kfd --device /dev/dri \
        --network host \
        --shm-size=16gb \
        --rm \
        -e HF_TOKEN \
+        -v ${HF_CACHE}:${HF_MOUNT} \
+        -e HF_HOME=${HF_MOUNT} \
        --name ${container_name} \
        ${image_name} \
        /bin/bash -c "${@}"
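The host-side Hugging Face cache added above is mounted into the container so that model downloads survive between CI runs. As a rough illustration of why pointing HF_HOME at the mounted directory is enough (this snippet is not part of the CI script, and the model ID is only a placeholder), huggingface_hub resolves its cache root from HF_HOME:

import os

# HF_HOME must be set before huggingface_hub is imported, because the cache
# location is resolved when the library loads.
os.environ.setdefault("HF_HOME", "/root/.cache/huggingface")

from huggingface_hub import snapshot_download

# Placeholder model ID. The first call downloads into $HF_HOME/hub; with that
# directory bind-mounted from the host, a later container run finds the files
# already cached and skips the download.
local_path = snapshot_download(repo_id="facebook/opt-125m")
print(local_path)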
@@ -44,7 +44,8 @@ steps:
  mirror_hardwares: [amd]
  fast_check: true
  commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
  - pytest -v -s basic_correctness/test_basic_correctness.py
  - pytest -v -s basic_correctness/test_cpu_offload.py
  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
# Try to find python package with an executable that exactly matches
@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
  # ROCm 5.X and 6.X
  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
      NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
-    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+    message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
      "expected for ROCm build, saw ${Torch_VERSION} instead.")
  endif()
else()
@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

-# Whether to build CK-based flash-attention
-# If 0, will not build flash attention
-# This is useful for gfx target where flash-attention is not supported
-# (i.e. those that do not appear in `FA_GFX_ARCHS`)
-# Triton FA is used by default on ROCm now so this is unnecessary.
+# Whether to install CK-based flash-attention
+# If 0, will not install flash-attention
ARG BUILD_FA="1"
+# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
+# If this succeeds, we use the downloaded wheel and skip building flash-attention.
+# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
+# architectures specified in `FA_GFX_ARCHS`
+ARG TRY_FA_WHEEL="1"
+ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="23a2b1c2"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
-ARG TRITON_BRANCH="0ef1848"
+ARG TRITON_BRANCH="e0fc12c"

### Base image build stage
FROM $BASE_IMAGE AS base
@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}

-RUN pip install --upgrade pip
+RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
-RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.1"*) \
-        pip uninstall -y torch torchaudio torchvision \
-        && pip install --no-cache-dir --pre \
+        python3 -m pip uninstall -y torch torchaudio torchvision \
+        && python3 -m pip install --no-cache-dir --pre \
            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
            torchvision==0.20.0.dev20240710 \
            --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
-    && pip wheel . --wheel-dir=/install
+    && python3 -m pip wheel . --wheel-dir=/install


### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
+ARG TRY_FA_WHEEL
+ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_FA" = "1" ]; then \
-        mkdir -p libs \
-        && cd libs \
-        && git clone https://github.com/ROCm/flash-attention.git \
-        && cd flash-attention \
-        && git checkout "${FA_BRANCH}" \
-        && git submodule update --init \
-        && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+        if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
+            # If a suitable wheel exists, we download it instead of building FA
+            mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
+        else \
+            mkdir -p libs \
+            && cd libs \
+            && git clone https://github.com/ROCm/flash-attention.git \
+            && cd flash-attention \
+            && git checkout "${FA_BRANCH}" \
+            && git submodule update --init \
+            && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+        fi; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install --upgrade numba scipy huggingface-hub[cli]
+    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false

RUN --mount=type=cache,target=${CCACHE_DIR} \
    --mount=type=cache,target=/root/.cache/pip \
-    pip install -U -r requirements-rocm.txt \
+    python3 -m pip install -Ur requirements-rocm.txt \
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.1"*) \
        # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
@@ -153,7 +163,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
    mkdir -p libs \
    && cp /install/*.whl libs \
    # Preemptively uninstall to avoid same-version no-installs
-    && pip uninstall -y amdsmi;
+    && python3 -m pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
@@ -161,7 +171,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-        && pip uninstall -y triton; fi
+        && python3 -m pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
@@ -169,11 +179,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Preemptively uninstall to avoid same-version no-installs
-        && pip uninstall -y flash-attn; fi
+        && python3 -m pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
    if ls libs/*.whl; then \
-    pip install libs/*.whl; fi
+    python3 -m pip install libs/*.whl; fi

CMD ["/bin/bash"]
@@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor

Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_

-2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
+2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_

-Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
+Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
+Alternatively, wheels intended for vLLM use can be accessed under the releases.

.. note::
-    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
    - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

3. Build vLLM.
@@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/fl
.. tip::

    - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
+    - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
    - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
    - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
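The tip about ``VLLM_USE_TRITON_FLASH_ATTN=0`` is applied programmatically in the test changes further down in this commit. A minimal sketch of that pattern, mirroring the approach used in the test files below:

import os

from vllm.utils import is_hip

# Disable Triton flash attention on ROCm before vLLM builds its attention
# backend; it then falls back to CK flash-attention or naive attention.
if is_hip():
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"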
@@ -2,5 +2,9 @@
-r requirements-common.txt

# Dependencies for AMD GPUs
awscli
boto3
botocore
ray >= 2.10.0
peft
pytest-asyncio
@@ -1,8 +1,13 @@
+from vllm.utils import is_hip
+
from ..utils import compare_two_settings


def test_cpu_offload():
    compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                         ["--cpu-offload-gb", "4"])
-    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
-                         [], ["--cpu-offload-gb", "1"])
+    if not is_hip():
+        # compressed-tensors quantization is currently not supported in ROCm.
+        compare_two_settings(
+            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
+            ["--cpu-offload-gb", "1"])
@@ -1,3 +1,4 @@
+import os
from typing import List, Optional, Tuple, Type

import pytest
@@ -5,6 +6,7 @@ from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
+from vllm.utils import is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -22,6 +24,12 @@ IMAGE_TOKEN_ID = 257152

models = ["google/paligemma-3b-mix-224"]

+# ROCm Triton FA can run into compilation issues with these models due to,
+# excessive use of shared memory. Use other backends in the meantime.
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if is_hip():
+    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+

def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
@@ -130,7 +138,15 @@ def run_test(
        [0.25, 0.5, 1.0],
    ],
)
-@pytest.mark.parametrize("dtype", ["float", "half"])
+@pytest.mark.parametrize("dtype", [
+    pytest.param(
+        "float",
+        marks=pytest.mark.skipif(
+            is_hip(),
+            reason=
+            "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
+    ), "half"
+])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@@ -1,3 +1,4 @@
+import os
import re
from typing import List, Optional, Tuple, Type

@@ -6,7 +7,7 @@ from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
-from vllm.utils import is_cpu
+from vllm.utils import is_cpu, is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -47,6 +48,12 @@ target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"

+# ROCm Triton FA can run into shared memory issues with these models,
+# use other backends in the meantime
+# FIXME (mattwong, gshtrasb, hongxiayan)
+if is_hip():
+    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
+

def run_test(
    hf_runner: Type[HfRunner],
@@ -275,6 +275,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
+            if self.sliding_window != (-1, -1):
+                logger.warning("ROCm Triton FA does not currently support "
+                               "sliding window attention. If using half "
+                               "precision, please try using the ROCm CK "
+                               "FA backend instead by setting the env var "
+                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
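The check above relies on the window_size convention used by flash-attention: the sliding window is stored as a two-element tuple and (-1, -1) means no sliding window. A tiny standalone sketch of that convention (the helper name is illustrative, not vLLM API):

from typing import Tuple


def uses_sliding_window(window: Tuple[int, int]) -> bool:
    # (-1, -1) is the sentinel for "no sliding window".
    return window != (-1, -1)


assert not uses_sliding_window((-1, -1))
assert uses_sliding_window((4095, 0))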
@@ -434,6 +440,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
                )

            # common code for prefill
@@ -87,13 +87,24 @@ _ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
+_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
+                    "Triton flash attention. For half-precision SWA support, "
+                    "please use CK flash attention by setting "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
+    _ROCM_SWA_REASON,
    "MistralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
+    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
+    _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration":
+    ("ROCm flash attention does not yet "
+     "fully support 32-bit precision on PaliGemma"),
+    "Phi3VForCausalLM":
+    ("ROCm Triton flash attention may run into compilation errors due to "
+     "excessive use of shared memory. If this happens, disable Triton FA "
+     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
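These entries are informational: a partially supported architecture still loads, but the stored reason can be surfaced to the user as a warning. A hypothetical sketch of how such a table can be consulted (the helper below is illustrative only, not vLLM's actual registry code):

import logging
from typing import Dict

logger = logging.getLogger(__name__)

# Trimmed copy of the mapping above, for illustration only.
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "MistralForCausalLM":
    ("Sliding window attention (SWA) is not yet supported in "
     "Triton flash attention."),
}


def warn_if_partially_supported_on_rocm(architecture: str) -> None:
    # Look up the architecture and log the reason instead of failing hard.
    reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(architecture)
    if reason is not None:
        logger.warning("Architecture %s is partially supported on ROCm: %s",
                       architecture, reason)


warn_if_partially_supported_on_rocm("MistralForCausalLM")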
@@ -3,7 +3,14 @@ from typing import List, Optional
import torch

from vllm import _custom_ops as ops
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+
+try:
+    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+except ModuleNotFoundError:
+    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
+    from vllm.attention.backends.rocm_flash_attn import (
+        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
+
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
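The same guarded-import idiom works for any optional dependency that has a drop-in replacement; a generic example unrelated to vLLM's actual modules (ujson is only a stand-in here):

# Prefer an accelerated JSON backend when it is installed; otherwise fall
# back to the standard library, which exposes the same loads/dumps surface.
try:
    import ujson as json
except ModuleNotFoundError:
    import json

print(json.dumps({"backend": json.__name__}))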