# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

# re-declare the ARG so it is visible after FROM
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"

# whether to build flash-attention
# if 0, flash-attention will not be built
# this is useful for gfx targets where flash-attention is not supported;
# in that case, vLLM falls back to its python reference attention implementation
ARG BUILD_FA="1"

# whether to build triton on rocm
ARG BUILD_TRITON="1"

# Install python3 and pip
RUN apt-get update && apt-get install -y python3 python3-pip

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
 && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to ${APP_MOUNT}
# (/vllm-workspace by default)
ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN if [ "$BUILD_FA" = "1" ]; then \
    mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCm/flash-attention.git \
    && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && export GPU_ARCHS=${FA_GFX_ARCHS} \
    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
    && python3 setup.py install \
    && cd ..; \
    fi

# The ROCm 6.0 base image ships numpy 1.20.3 in a broken state: its dist-info has no
# METADATA but contains an extra LICENSES_bundled.txt.
# Remove it manually so that the numpy upgrade in later steps can proceed.
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# Build triton on ROCm
RUN if [ "$BUILD_TRITON" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
    && pip uninstall -y triton \
    && git clone https://github.com/ROCm/triton.git \
    && cd triton/python \
    && pip3 install . \
    && cd ../..; \
    fi

WORKDIR /vllm-workspace
COPY . .

RUN python3 -m pip install --upgrade pip numba

# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

# Install vLLM itself, patching the ROCm bfloat16 header first
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
    && python3 setup.py install \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
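
# Example build commands (a sketch; the file name "Dockerfile.rocm" and the image tag
# "vllm-rocm" are assumptions, adjust them to your checkout):
#
#   # default ROCm 6.0 base image
#   docker build -f Dockerfile.rocm -t vllm-rocm .
#
#   # ROCm 5.7 base image, skipping flash-attention for unsupported gfx targets
#   docker build -f Dockerfile.rocm \
#     --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
#     --build-arg BUILD_FA="0" \
#     -t vllm-rocm .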
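#
# Example run command mounting the code directory at /vllm-workspace (a sketch; the
# device and group flags are the usual ones for ROCm containers, adjust to your setup):
#
#   docker run -it --device=/dev/kfd --device=/dev/dri --group-add video \
#     -v "$(pwd)":/vllm-workspace vllm-rocm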