Cmake based build system (#2830)

This commit is contained in:
bnellnm 2024-03-18 18:38:33 -04:00 committed by GitHub
parent c0c17d4896
commit 9fdf3de346
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 868 additions and 302 deletions

279
CMakeLists.txt Normal file
View File

@ -0,0 +1,279 @@
cmake_minimum_required(VERSION 3.21)
project(vllm_extensions LANGUAGES CXX)
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
#
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
#
# Supported/expected torch versions for CUDA/ROCm.
#
# Currently, having an incorrect pytorch version results in a warning
# rather than an error.
#
# Note: the CUDA torch version is derived from pyproject.toml and various
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
#
# Try to find python package with an executable that exactly matches
# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions.
#
if (VLLM_PYTHON_EXECUTABLE)
find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")
else()
message(FATAL_ERROR
"Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version"
" before running cmake configure.")
endif()
#
# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
#
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
#
# Import torch cmake configuration.
# Torch also imports CUDA (and partially HIP) languages with some customizations,
# so there is no need to do this explicitly with check_language/enable_language,
# etc.
#
find_package(Torch REQUIRED)
#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it by manually using `append_torchlib_if_found` from
# torch's cmake setup.
#
append_torchlib_if_found(torch_python)
#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.
#
if (NOT HIP_FOUND AND CUDA_FOUND)
set(VLLM_GPU_LANG "CUDA")
if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
"expected for CUDA build, saw ${Torch_VERSION} instead.")
endif()
elseif(HIP_FOUND)
set(VLLM_GPU_LANG "HIP")
# Importing torch recognizes and sets up some HIP/ROCm configuration but does
# not let cmake recognize .hip files. In order to get cmake to understand the
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)
# ROCm 5.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
"expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
endif()
# ROCm 6.x
if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
"expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
endif()
else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
#
override_gpu_arches(VLLM_GPU_ARCHES
${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
# The final set of arches is stored in `VLLM_GPU_FLAGS`.
#
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
#
# Set nvcc parallelism.
#
if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()
#
# Define extension targets
#
#
# _C extension
#
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu")
endif()
define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
WITH_SOABI)
#
# _moe_C extension
#
set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/topk_softmax_kernels.cu")
define_gpu_extension_target(
_moe_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
WITH_SOABI)
#
# _punica_C extension
#
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
"csrc/punica/punica_ops.cc")
#
# Copy GPU compilation flags+update for punica
#
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
#
# Filter out CUDA architectures < 8.0 for punica.
#
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
set(VLLM_PUNICA_GPU_ARCHES)
foreach(ARCH ${VLLM_GPU_ARCHES})
string_to_ver(CODE_VER ${ARCH})
if (CODE_VER GREATER_EQUAL 8.0)
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
define_gpu_extension_target(
_punica_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
endif()
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
# there are supported target arches.
if (VLLM_PUNICA_GPU_ARCHES AND
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
message(STATUS "Enabling punica extension.")
add_dependencies(default _punica_C)
endif()
endif()

View File

@ -38,6 +38,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

View File

@ -1,4 +1,6 @@
include LICENSE
include requirements.txt
include CMakeLists.txt
recursive-include cmake *
recursive-include csrc *

73
cmake/hipify.py Executable file
View File

@ -0,0 +1,73 @@
#!/usr/bin/env python3
#
# A command line tool for running pytorch's hipify preprocessor on CUDA
# source files.
#
# See https://github.com/ROCm/hipify_torch
# and <torch install dir>/utils/hipify/hipify_python.py
#
import argparse
import shutil
import os
from torch.utils.hipify.hipify_python import hipify
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Project directory where all the source + include files live.
parser.add_argument(
"-p",
"--project_dir",
help="The project directory.",
)
# Directory where hipified files are written.
parser.add_argument(
"-o",
"--output_dir",
help="The output directory.",
)
# Source files to convert.
parser.add_argument("sources",
help="Source files to hipify.",
nargs="*",
default=[])
args = parser.parse_args()
# Limit include scope to project_dir only
includes = [os.path.join(args.project_dir, '*')]
# Get absolute path for all source files.
extra_files = [os.path.abspath(s) for s in args.sources]
# Copy sources from project directory to output directory.
# The directory might already exist to hold object files so we ignore that.
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
hipify_result = hipify(project_directory=args.project_dir,
output_directory=args.output_dir,
header_include_dirs=[],
includes=includes,
extra_files=extra_files,
show_detailed=True,
is_pytorch_extension=True,
hipify_extra_files_only=True)
hipified_sources = []
for source in args.sources:
s_abs = os.path.abspath(source)
hipified_s_abs = (hipify_result[s_abs].hipified_path if
(s_abs in hipify_result
and hipify_result[s_abs].hipified_path is not None)
else s_abs)
hipified_sources.append(hipified_s_abs)
assert (len(hipified_sources) == len(args.sources))
# Print hipified source files.
print("\n".join(hipified_sources))

334
cmake/utils.cmake Normal file
View File

@ -0,0 +1,334 @@
#
# Attempt to find the python package that uses the same python executable as
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
#
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
message(FATAL_ERROR
"Python version (${_VER}) is not one of the supported versions: "
"${_SUPPORTED_VERSIONS_LIST}.")
endif()
message(STATUS "Found python matching: ${EXECUTABLE}.")
endmacro()
#
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
# has trailing whitespace stripped. If an error is encountered when running
# python, a fatal message `ERR_MSG` is issued.
#
function (run_python OUT EXPR ERR_MSG)
execute_process(
COMMAND
"${Python_EXECUTABLE}" "-c" "${EXPR}"
OUTPUT_VARIABLE PYTHON_OUT
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
endif()
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
endfunction()
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
macro (append_cmake_prefix_path PKG EXPR)
run_python(_PREFIX_PATH
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
#
# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
# of CUDA source files. The names of the corresponding "hipified" sources are
# stored in `OUT_SRCS`.
#
function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
#
# Split into C++ and non-C++ (i.e. CUDA) sources.
#
set(SRCS ${ORIG_SRCS})
set(CXX_SRCS ${ORIG_SRCS})
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
#
# Generate ROCm/HIP source file names from CUDA file names.
# Since HIP files are generated code, they will appear in the build area
# `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
#
set(HIP_SRCS)
foreach (SRC ${SRCS})
string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
endforeach()
set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
add_custom_target(
hipify${NAME}
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
BYPRODUCTS ${HIP_SRCS}
COMMENT "Running hipify on ${NAME} extension source files.")
# Swap out original extension sources with hipified sources.
list(APPEND HIP_SRCS ${CXX_SRCS})
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
endfunction()
#
# Get additional GPU compiler flags from torch.
#
function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
if (${GPU_LANG} STREQUAL "CUDA")
#
# Get common NVCC flags from torch.
#
run_python(GPU_FLAGS
"from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
"Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
endif()
elseif(${GPU_LANG} STREQUAL "HIP")
#
# Get common HIP/HIPCC flags from torch.
#
run_python(GPU_FLAGS
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags")
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
# Macro for converting a `gencode` version number to a cmake version number.
macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`.
#
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
#
macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
if (${GPU_LANG} STREQUAL "HIP")
#
# `GPU_ARCHES` controls the `--offload-arch` flags.
# `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
# via the `PYTORCH_ROCM_ARCH` env variable.
#
#
# Find the intersection of the supported + detected architectures to
# set the module architecture flags.
#
set(${GPU_ARCHES})
foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES})
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
list(APPEND ${GPU_ARCHES} ${_ARCH})
endif()
endforeach()
if(NOT ${GPU_ARCHES})
message(FATAL_ERROR
"None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()
elseif(${GPU_LANG} STREQUAL "CUDA")
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis, e.g. for the `punica` extension.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if (NOT _CUDA_ARCH_FLAGS)
message(FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
# Process each `gencode` flag.
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
if (_SM)
set(_SM ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
if (_CODE)
set(_CODE ${CMAKE_MATCH_1})
endif()
# Make sure the virtual architecture can be matched.
if (NOT _COMPUTE)
message(FATAL_ERROR
"Could not determine virtual architecture from: ${_ARCH}.")
endif()
# One of sm_ or compute_ must exist.
if ((NOT _SM) AND (NOT _CODE))
message(FATAL_ERROR
"Could not determine a codegen architecture from: ${_ARCH}.")
endif()
if (_SM)
set(_VIRT "")
set(_CODE_ARCH ${_SM})
else()
set(_VIRT "-virtual")
set(_CODE_ARCH ${_CODE})
endif()
# Check if the current version is in the supported arch list.
string_to_ver(_CODE_VER ${_CODE_ARCH})
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
continue()
endif()
# Add it to the arch list.
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
endforeach()
endif()
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro()
#
# Define a target named `GPU_MOD_NAME` for a single extension. The
# arguments are:
#
# DESTINATION <dest> - Module destination directory.
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
# etc.
# SOURCES <sources> - List of source files relative to CMakeLists.txt
# directory.
#
# Optional arguments:
#
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
# format.
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
# and `CMAKE_HIP_ARCHITECTURES` for more info.
# ARCHITECTURES will use cmake's defaults if
# not provided.
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LINK_LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
#
# Note: optimization level/debug info is set via cmake build type.
#
function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm.
if (GPU_LANGUAGE STREQUAL "HIP")
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
endif()
if (GPU_WITH_SOABI)
set(GPU_WITH_SOABI WITH_SOABI)
else()
set(GPU_WITH_SOABI)
endif()
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
endif()
if (GPU_ARCHITECTURES)
set_target_properties(${GPU_MOD_NAME} PROPERTIES
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
endif()
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
target_compile_options(${GPU_MOD_NAME} PRIVATE
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}
${GPU_LIBRARIES})
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
endfunction()

View File

@ -1,6 +1,7 @@
[build-system]
# Should be mirrored in requirements-build.txt
requires = [
"cmake>=3.21",
"ninja",
"packaging",
"setuptools >= 49.4.0",

View File

@ -1,4 +1,5 @@
# Should be mirrored in pyproject.toml
cmake>=3.21
ninja
packaging
setuptools>=49.4.0

View File

@ -1,3 +1,4 @@
cmake>=3.21
ninja # For faster builds.
typing-extensions>=4.8.0
starlette

View File

@ -1,3 +1,4 @@
cmake>=3.21
ninja # For faster builds.
psutil
ray >= 2.9

474
setup.py
View File

@ -1,23 +1,16 @@
import contextlib
import io
import os
import re
import subprocess
import warnings
from pathlib import Path
from typing import List, Set
import sys
from typing import List
from packaging.version import parse, Version
import setuptools
import sys
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from shutil import which
import torch
import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
CUDA_HOME,
ROCM_HOME,
)
from torch.utils.cpp_extension import CUDA_HOME
ROOT_DIR = os.path.dirname(__file__)
@ -25,17 +18,153 @@ ROOT_DIR = os.path.dirname(__file__)
assert sys.platform.startswith(
"linux"), "vLLM only supports Linux platform (including WSL)."
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# The downside is that this method is deprecated, see
# https://github.com/pypa/setuptools/issues/917
MAIN_CUDA_VERSION = "12.1"
# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
def is_sccache_available() -> bool:
return which("sccache") is not None
def is_ccache_available() -> bool:
return which("ccache") is not None
def is_ninja_available() -> bool:
return which("ninja") is not None
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix):]
return text
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
super().__init__(name, sources=[], **kwa)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config = {}
#
# Determine number of compilation jobs and optionally nvcc compile threads.
#
def compute_num_jobs(self):
try:
# os.sched_getaffinity() isn't universally available, so fall back
# to os.cpu_count() if we get an error here.
num_jobs = len(os.sched_getaffinity(0))
except AttributeError:
num_jobs = os.cpu_count()
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_jobs = max(1, round(num_jobs / (nvcc_threads / 4)))
else:
nvcc_threads = None
return num_jobs, nvcc_threads
#
# Perform cmake configuration for a single extension.
#
def configure(self, ext: CMakeExtension) -> None:
# If we've already configured using the CMakeLists.txt for
# this extension, exit early.
if ext.cmake_lists_dir in cmake_build_ext.did_config:
return
cmake_build_ext.did_config[ext.cmake_lists_dir] = True
# Select the build type.
# Note: optimization level + debug info are set by the build type
default_cfg = "Debug" if self.debug else "RelWithDebInfo"
cfg = os.getenv("CMAKE_BUILD_TYPE", default_cfg)
# where .so files will be written, should be the same for all extensions
# that use the same CMakeLists.txt.
outdir = os.path.abspath(
os.path.dirname(self.get_ext_fullpath(ext.name)))
cmake_args = [
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
]
verbose = bool(int(os.getenv('VERBOSE', '0')))
if verbose:
cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON']
if is_sccache_available():
cmake_args += [
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
]
elif is_ccache_available():
cmake_args += [
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
]
# Pass the python executable to cmake so it can find an exact
# match.
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
if _install_punica():
cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON']
#
# Setup parallelism and build tool
#
num_jobs, nvcc_threads = self.compute_num_jobs()
if nvcc_threads:
cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)]
if is_ninja_available():
build_tool = ['-G', 'Ninja']
cmake_args += [
'-DCMAKE_JOB_POOL_COMPILE:STRING=compile',
'-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs),
]
else:
# Default build tool to whatever cmake picks.
build_tool = []
subprocess.check_call(
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
cwd=self.build_temp)
def build_extensions(self) -> None:
# Ensure that CMake is present and working
try:
subprocess.check_output(['cmake', '--version'])
except OSError as e:
raise RuntimeError('Cannot find CMake executable') from e
# Create build directory if it does not exist.
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
ext_target_name = remove_prefix(ext.name, "vllm.")
num_jobs, _ = self.compute_num_jobs()
build_args = [
'--build', '.', '--target', ext_target_name, '-j',
str(num_jobs)
]
subprocess.check_call(['cmake', *build_args], cwd=self.build_temp)
def _is_cuda() -> bool:
@ -55,26 +184,8 @@ def _is_neuron() -> bool:
return torch_neuronx_installed
# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]
if _is_hip():
if ROCM_HOME is None:
raise RuntimeError("Cannot find ROCM_HOME. "
"ROCm must be available to build the package.")
NVCC_FLAGS += ["-DUSE_ROCM"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]
if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
def _install_punica() -> bool:
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
def get_hipcc_rocm_version():
@ -99,11 +210,6 @@ def get_hipcc_rocm_version():
return None
def glob(pattern: str):
root = Path(__name__).parent
return [str(p) for p in root.glob(pattern)]
def get_neuronxcc_version():
import sysconfig
site_dir = sysconfig.get_paths()["purelib"]
@ -123,12 +229,12 @@ def get_neuronxcc_version():
raise RuntimeError("Could not find HIP version in the output")
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
def get_nvcc_cuda_version() -> Version:
"""Get the CUDA version from nvcc.
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
@ -136,250 +242,6 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
return nvcc_cuda_version
def get_pytorch_rocm_arch() -> Set[str]:
"""Get the cross section of Pytorch,and vllm supported gfx arches
ROCM can get the supported gfx architectures in one of two ways
Either through the PYTORCH_ROCM_ARCH env var, or output from
rocm_agent_enumerator.
In either case we can generate a list of supported arch's and
cross reference with VLLM's own ROCM_SUPPORTED_ARCHs.
"""
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from
# rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = (subprocess.check_output(
[command]).decode('utf-8').strip().replace("\n", ";"))
arch_source_str = "rocm_agent_enumerator"
else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable"
# List are separated by ; or space.
pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
# Filter out the invalid architectures and print a warning.
arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
f"None of the ROCM architectures in {arch_source_str} "
f"({env_arch_list}) is supported. "
f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
if invalid_arch_list:
warnings.warn(
f"Unsupported ROCM architectures ({invalid_arch_list}) are "
f"excluded from the {arch_source_str} output "
f"({env_arch_list}). Supported ROCM architectures are: "
f"{ROCM_SUPPORTED_ARCHS}.",
stacklevel=2)
return arch_list
def get_torch_arch_list() -> Set[str]:
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
# compiler to additionally include PTX code that can be runtime-compiled
# and executed on the 8.6 or newer architectures. While the PTX code will
# not give the best performance on the newer architectures, it provides
# forward compatibility.
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if env_arch_list is None:
return set()
# List are separated by ; or space.
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
if not torch_arch_list:
return set()
# Filter out the invalid architectures and print a warning.
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
{s + "+PTX"
for s in NVIDIA_SUPPORTED_ARCHS})
arch_list = torch_arch_list.intersection(valid_archs)
# If none of the specified architectures are valid, raise an error.
if not arch_list:
raise RuntimeError(
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
f"variable ({env_arch_list}) is supported. "
f"Supported CUDA architectures are: {valid_archs}.")
invalid_arch_list = torch_arch_list - valid_archs
if invalid_arch_list:
warnings.warn(
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
f"({env_arch_list}). Supported CUDA architectures are: "
f"{valid_archs}.",
stacklevel=2)
return arch_list
if _is_hip():
rocm_arches = get_pytorch_rocm_arch()
NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
else:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# GPUs on the current machine.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 7:
raise RuntimeError(
"GPUs with compute capability below 7.0 are not supported.")
compute_capabilities.add(f"{major}.{minor}")
ext_modules = []
if _is_cuda():
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
# If no GPU is specified nor available, add all supported architectures
# based on the NVCC CUDA version.
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
if nvcc_cuda_version < Version("11.1"):
compute_capabilities.remove("8.6")
if nvcc_cuda_version < Version("11.8"):
compute_capabilities.remove("8.9")
compute_capabilities.remove("9.0")
# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("11.0"):
raise RuntimeError(
"CUDA 11.0 or higher is required to build the package.")
if (nvcc_cuda_version < Version("11.1")
and any(cc.startswith("8.6") for cc in compute_capabilities)):
raise RuntimeError(
"CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute
# capability 8.9. However, GPUs with compute capability 8.9 can
# also run the code generated by the previous versions of CUDA 11
# and targeting compute capability 8.0. Therefore, if CUDA 11.8
# is not available, we target compute capability 8.0 instead of 8.9.
warnings.warn(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead.",
stacklevel=2)
compute_capabilities = set(cc for cc in compute_capabilities
if not cc.startswith("8.9"))
compute_capabilities.add("8.0+PTX")
if any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 11.8 or higher is required for compute capability 9.0.")
NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
num = capability[0] + capability[2]
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += [
"-gencode", f"arch=compute_{num},code=compute_{num}"
]
if int(capability[0]) >= 8:
NVCC_FLAGS_PUNICA += [
"-gencode", f"arch=compute_{num},code=sm_{num}"
]
if capability.endswith("+PTX"):
NVCC_FLAGS_PUNICA += [
"-gencode", f"arch=compute_{num},code=compute_{num}"
]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
num_threads = min(os.cpu_count(), nvcc_threads)
NVCC_FLAGS += ["--threads", str(num_threads)]
if nvcc_cuda_version >= Version("11.8"):
NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
# changes for punica kernels
NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
REMOVE_NVCC_FLAGS = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
for flag in REMOVE_NVCC_FLAGS:
with contextlib.suppress(ValueError):
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
install_punica = False
break
if install_punica:
ext_modules.append(
CUDAExtension(
name="vllm._punica_C",
sources=["csrc/punica/punica_ops.cc"] +
glob("csrc/punica/bgmv/*.cu"),
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()
vllm_extension_sources = [
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
"csrc/pos_encoding_kernels.cu",
"csrc/activation_kernels.cu",
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/moe_align_block_size_kernels.cu",
"csrc/pybind.cpp",
]
if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
vllm_extension_sources.append(
"csrc/quantization/marlin/marlin_cuda_kernel.cu")
vllm_extension_sources.append("csrc/custom_all_reduce.cu")
# Add MoE kernels.
ext_modules.append(
CUDAExtension(
name="vllm._moe_C",
sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
))
if not _is_neuron():
vllm_extension = CUDAExtension(
name="vllm._C",
sources=vllm_extension_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
libraries=["cuda"] if _is_cuda() else [],
)
ext_modules.append(vllm_extension)
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
@ -401,7 +263,7 @@ def get_vllm_version() -> str:
version = find_version(get_path("vllm", "__init__.py"))
if _is_cuda():
cuda_version = str(nvcc_cuda_version)
cuda_version = str(get_nvcc_cuda_version())
if cuda_version != MAIN_CUDA_VERSION:
cuda_version_str = cuda_version.replace(".", "")[:3]
version += f"+cu{cuda_version_str}"
@ -413,7 +275,7 @@ def get_vllm_version() -> str:
version += f"+rocm{rocm_version_str}"
elif _is_neuron():
# Get the Neuron version
neuron_version = str(neuronxcc_version)
neuron_version = str(get_neuronxcc_version())
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
@ -437,7 +299,7 @@ def get_requirements() -> List[str]:
if _is_cuda():
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
if nvcc_cuda_version <= Version("11.8"):
if get_nvcc_cuda_version() <= Version("11.8"):
# replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
for i in range(len(requirements)):
if requirements[i].startswith("cupy-cuda12x"):
@ -456,14 +318,24 @@ def get_requirements() -> List[str]:
return requirements
ext_modules = []
if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
if not _is_neuron():
ext_modules.append(CMakeExtension(name="vllm._C"))
package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}
if os.environ.get("VLLM_USE_PRECOMPILED"):
ext_modules = []
package_data["vllm"].append("*.so")
setuptools.setup(
setup(
name="vllm",
version=get_vllm_version(),
author="vLLM Team",
@ -485,11 +357,11 @@ setuptools.setup(
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
"examples", "tests")),
packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples",
"tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)