[Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (#5422)

This commit is contained in:
Matt Wong 2024-06-25 17:56:15 -05:00 committed by GitHub
parent bc34937d68
commit dd793d1de5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 257 additions and 120 deletions

View File

@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
#
# Try to find python package with an executable that exactly matches
@ -98,18 +97,11 @@ elseif(HIP_FOUND)
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)
# ROCm 5.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
"expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
endif()
# ROCm 6.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
"expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")

View File

@ -1,34 +1,35 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
# Default ROCm 6.1 base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
FROM $BASE_IMAGE
# Tested and supported base rocm/pytorch images
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
RUN echo "Base image is $BASE_IMAGE"
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"
# whether to build flash-attention
# if 0, will not build flash attention
# this is useful for gfx target where flash-attention is not supported
# In that case, we need to use the python reference attention implementation in vllm
# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
ARG BUILD_FA="1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c"
# whether to build triton on rocm
# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"
### Base image build stage
FROM $BASE_IMAGE AS base
# Import arg(s) defined before this build stage
ARG PYTORCH_ROCM_ARCH
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
# Install some basic utilities
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
build-essential \
wget \
unzip \
nvidia-cuda-toolkit \
tmux \
ccache \
&& rm -rf /var/lib/apt/lists/*
### Mount Point ###
# When launching the container, mount the code directory to /app
# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.4.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
*"rocm-6.0"*) \
pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
*"rocm-6.1"*) \
pip uninstall -y torch \
&& pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
# Install ROCm flash-attention
RUN if [ "$BUILD_FA" = "1" ]; then \
mkdir libs \
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV CCACHE_DIR=/root/.cache/ccache
### AMD-SMI build stage
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install
### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
&& if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..; \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
&& patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
*) ;; esac \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually removed it so that later steps of numpy upgrade can continue
RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
### Triton wheel build stage
FROM base AS build_triton
ARG BUILD_TRITON
ARG TRITON_BRANCH
# Build triton wheel if `BUILD_TRITON = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
&& git clone https://github.com/OpenAI/triton.git \
&& cd triton \
&& git checkout "${TRITON_BRANCH}" \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
WORKDIR /vllm-workspace
### Final vLLM build stage
FROM base AS final
# Import the vLLM development directory from the build context
COPY . .
#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac
# make sure punica kernels are built (for LoRA)
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli]
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false
ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
&& python3 setup.py install \
&& export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
&& cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
&& cd ..
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.0"*) \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
&& cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
# Prevent interference if torch bundles its own HIP runtime
&& rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
*) ;; esac \
&& python3 setup.py clean --all \
&& python3 setup.py develop
# Copy amdsmi wheel into final image
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi;
# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi
# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi
# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
pip install libs/*.whl; fi
CMD ["/bin/bash"]

View File

@ -147,19 +147,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
if (${GPU_LANG} STREQUAL "HIP")
#
# `GPU_ARCHES` controls the `--offload-arch` flags.
# `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
# via the `PYTORCH_ROCM_ARCH` env variable.
#
# If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
# if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
# "rocm_agent_enumerator" in "enable_language(HIP)"
# (in file Modules/CMakeDetermineHIPCompiler.cmake)
#
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
else()
set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
endif()
#
# Find the intersection of the supported + detected architectures to
# set the module architecture flags.
#
set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
set(${GPU_ARCHES})
foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS})
foreach (_ARCH ${HIP_ARCHITECTURES})
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
list(APPEND ${GPU_ARCHES} ${_ARCH})
endif()
@ -167,7 +171,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
if(NOT ${GPU_ARCHES})
message(FATAL_ERROR
"None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()

View File

@ -88,7 +88,7 @@ Option 2: Build from source
- `Pytorch <https://pytorch.org/>`_
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started <https://pytorch.org/get-started/locally/>`_
@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
.. tip::
- You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention.
- To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of pytorch, ideally, should match the ROCm driver version.

View File

@ -4,7 +4,7 @@ import pytest
# and debugging.
import ray
from ..utils import VLLM_PATH, RemoteOpenAIServer
from ..utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
@ -12,7 +12,7 @@ MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
yield
ray.shutdown()

View File

@ -1,8 +1,8 @@
import os
import ray
from vllm.utils import cuda_device_count_stateless
import vllm.envs as envs
from vllm.utils import (cuda_device_count_stateless, is_hip,
update_environment_variables)
@ray.remote
@ -12,16 +12,21 @@ class _CUDADeviceCountStatelessTestActor:
return cuda_device_count_stateless()
def set_cuda_visible_devices(self, cuda_visible_devices: str):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
def get_cuda_visible_devices(self):
return os.environ["CUDA_VISIBLE_DEVICES"]
return envs.CUDA_VISIBLE_DEVICES
def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed."""
if is_hip():
# Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion
# is handled by `update_environment_variables`
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote()
assert sorted(ray.get(

View File

@ -2,7 +2,7 @@ import openai
import pytest
import ray
from ..utils import VLLM_PATH, RemoteOpenAIServer
from ..utils import RemoteOpenAIServer
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@ -11,7 +11,7 @@ pytestmark = pytest.mark.openai
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
yield
ray.shutdown()

View File

@ -16,7 +16,7 @@ from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
from ..utils import VLLM_PATH, RemoteOpenAIServer
from ..utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@ -81,7 +81,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
yield
ray.shutdown()

View File

@ -8,7 +8,7 @@ import ray
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
from ..utils import VLLM_PATH, RemoteOpenAIServer
from ..utils import RemoteOpenAIServer
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
@ -27,7 +27,7 @@ pytestmark = pytest.mark.openai
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
ray.init()
yield
ray.shutdown()

View File

@ -15,9 +15,30 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port, is_hip
if (not is_hip()):
if is_hip():
from amdsmi import (amdsmi_get_gpu_vram_usage,
amdsmi_get_processor_handles, amdsmi_init,
amdsmi_shut_down)
@contextmanager
def _nvml():
try:
amdsmi_init()
yield
finally:
amdsmi_shut_down()
else:
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
nvmlInit, nvmlShutdown)
@contextmanager
def _nvml():
try:
nvmlInit()
yield
finally:
nvmlShutdown()
# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@ -160,20 +181,25 @@ def error_on_warning():
yield
@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit()
start_time = time.time()
while True:
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
if is_hip():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10
else:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'

View File

@ -7,13 +7,15 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_tpu, is_xpu)
is_hip, is_neuron, is_tpu, is_xpu,
update_environment_variables)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -634,6 +636,12 @@ class ParallelConfig:
self.distributed_executor_backend = backend
logger.info("Defaulting to use %s for distributed inference",
backend)
# If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
# propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
# the update_environment_variables function)
if is_hip() and envs.CUDA_VISIBLE_DEVICES:
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
self._verify_args()

View File

@ -13,7 +13,8 @@ import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
logger = init_logger(__name__)
@ -24,7 +25,8 @@ def producer(batch_src: Sequence[int],
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
@ -56,7 +58,8 @@ def consumer(batch_tgt: Sequence[int],
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
update_environment_variables(
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
@ -123,7 +126,7 @@ def can_actually_p2p(
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs

View File

@ -11,7 +11,8 @@ from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (cuda_device_count_stateless,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
get_vllm_instance_id, make_async,
update_environment_variables)
logger = init_logger(__name__)
@ -25,8 +26,9 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
map(str, range(world_size))))
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

View File

@ -376,6 +376,10 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]):
if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
# Propagate changes to CUDA_VISIBLE_DEVICES to
# ROCm's HIP_VISIBLE_DEVICES as well
envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
@ -779,9 +783,14 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled():
return 0
# bypass _device_count_nvml() if rocm (not supported)
nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
if is_hip():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
torch.cuda, "_device_count_amdsmi")) else -1
else:
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
@ -795,7 +804,6 @@ def cuda_device_count_stateless() -> int:
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)

View File

@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread,
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
logger = init_logger(__name__)
@ -125,6 +125,14 @@ class WorkerWrapperBase:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
if is_hip():
hip_env_var = "HIP_VISIBLE_DEVICES"
if hip_env_var in os.environ:
logger.warning(
"Ignoring pre-set environment variable `%s=%s` as "
"%s has also been set, which takes precedence.",
hip_env_var, os.environ[hip_env_var], key)
os.environ.pop(hip_env_var, None)
update_environment_variables(envs)
def init_worker(self, *args, **kwargs):