[ROCm] add support to ROCm 6.0 and MI300 (#2274)

Hongxia Yang 2024-01-26 15:41:10 -05:00 committed by GitHub
parent 5265631d15
commit 6b7de1a030
8 changed files with 96 additions and 13 deletions

@@ -1,4 +1,24 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
FROM $BASE_IMAGE
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
RUN echo "Base image is $BASE_IMAGE"
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
# this does not always work for all rocm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
@@ -37,17 +57,23 @@ RUN mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
&& git checkout 3d2b6f5 \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.23 --no-deps
RUN python3 -m pip install xformers==0.0.23 --no-deps
# Work around an odd state of numpy 1.20.3 in the ROCm 6.0 base image: its dist-info has no METADATA etc, but an extra LICENSES_bundled.txt.
# Remove it manually so that the later numpy upgrade steps can continue.
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
RUN cd /app \
&& cd vllm \

@@ -26,7 +26,8 @@ Please register [here](https://lu.ma/ygxbpzhl) and join us!
---
*Latest News* 🔥
- [2023/12] Added ROCm support to vLLM.
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!

@@ -5,3 +5,6 @@
int get_device_attribute(
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute(
int device_id);

@@ -1,5 +1,6 @@
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
int attribute,
@@ -15,3 +16,20 @@ int get_device_attribute(
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
}
int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 (CUDA); hipDeviceAttributeMaxSharedMemoryPerBlock = 74 (ROCm)
#ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif
return get_device_attribute(attribute, device_id);
}

@@ -81,4 +81,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
}
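
For reference, a minimal Python sketch of how the new binding is reached (it mirrors the `get_max_shared_memory_bytes` change at the end of this diff):

    from vllm._C import cuda_utils

    # Queries cudaDevAttrMaxSharedMemoryPerBlockOptin on CUDA builds and
    # hipDeviceAttributeMaxSharedMemoryPerBlock on ROCm builds, for device 0.
    max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(0)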

@@ -11,10 +11,10 @@ Requirements
------------
* OS: Linux
* Python: 3.8 -- 3.11 (Verified on 3.10)
* GPU: MI200s
* Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942)
* PyTorch 2.0.1/2.1.1/2.2
* ROCm 5.7
* ROCm 5.7 (verified on Python 3.10) or ROCm 6.0 (verified on Python 3.9)
Installation options:
@@ -27,6 +27,8 @@ Installation options:
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
---------------------------------------------------------------------------
This option is for ROCm 5.7 only:
.. code-block:: console
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
@@ -50,6 +52,9 @@ Option 2: Build from source
You can build and install vLLM from source:
The instructions below are for ROCm 5.7 only.
At the time of this documentation update, the PyTorch wheel for ROCm 6.0 was not yet available on the PyTorch website.
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
@@ -95,6 +100,23 @@ You can build and install vLLM from source:
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
`Dockerfile.rocm` is designed to support ROCm 5.7, ROCm 6.0, and later versions. It provides flexibility to customize the docker image build using the following arguments:
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
* `FA_GFX_ARCHS`: specifies the GFX architectures used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
* `FA_BRANCH`: specifies the branch of `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_ used to build flash-attention. The default is `3d2b6f5`
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
For example, to build a docker image for vLLM on ROCm 5.7, you can run:
.. code-block:: console
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-f Dockerfile.rocm -t vllm-rocm .
To build vLLM on ROCm 6.0, you can use the defaults:
.. code-block:: console
$ docker build -f Dockerfile.rocm -t vllm-rocm .
@@ -142,3 +164,8 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py install # This may take 5-10 minutes.
.. note::
- You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
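
As an illustrative sketch (the model name is a placeholder), the same behavior can be enabled through the Python API via ``enforce_eager=True``, which is what the ``--enforce-eager`` flag sets:

.. code-block:: python

   from vllm import LLM

   # enforce_eager=True skips CUDA/HIP graph capture and runs the model
   # in eager mode, matching the --enforce-eager command-line flag.
   llm = LLM(model="facebook/opt-125m", enforce_eager=True)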

@@ -51,6 +51,8 @@ if _is_hip():
"Cannot find ROCM_HOME. ROCm must be available to build the package."
)
NVCC_FLAGS += ["-DUSE_ROCM"]
# Undefine HIP's half-precision guard macros so that __half conversions
# and operators are available when compiling the hipified kernels.
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]
if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(

@@ -112,10 +112,10 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
max_shared_mem = cuda_utils.get_device_attribute(
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(
gpu)
# value 0 will cause MAX_SEQ_LEN to become negative and test_attention.py will fail
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem)
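
As a usage sketch (assuming the helper lives in ``vllm.utils``; the derivation mirrors the pattern in `test_attention.py` and the constants are illustrative):

    from vllm.utils import get_max_shared_memory_bytes

    FLOAT32_BYTES = 4  # bytes per float32 element
    # Derive a maximum sequence length from the per-block shared-memory limit;
    # a return value of 0 would make this negative, hence the assert above.
    MAX_SEQ_LEN = get_max_shared_memory_bytes(gpu=0) // FLOAT32_BYTES - 512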