mirror of https://github.com/vllm-project/vllm
[ROCm] Add support for Punica kernels on AMD GPUs (#3140)
Co-authored-by: miloice <jeffaw99@hotmail.com>
parent 0ee535b294
commit ff5abcd746
@@ -219,7 +219,8 @@ set(VLLM_PUNICA_EXT_SRC
  "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
  "csrc/punica/punica_ops.cc")
  "csrc/punica/punica_ops.cu"
  "csrc/punica/punica_pybind.cpp")

#
# Copy GPU compilation flags+update for punica
@@ -243,6 +244,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
    endif()
  endforeach()
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
  set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()

if (VLLM_PUNICA_GPU_ARCHES)
@@ -277,11 +281,6 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
  message(STATUS "Enabling C extension.")
  add_dependencies(default _C)
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
  message(STATUS "Enabling moe extension.")
  add_dependencies(default _moe_C)

  # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
  # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
@@ -292,3 +291,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    add_dependencies(default _punica_C)
  endif()
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
  message(STATUS "Enabling moe extension.")
  add_dependencies(default _moe_C)
endif()
@@ -94,6 +94,9 @@ COPY . .

RUN python3 -m pip install --upgrade pip numba

# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
@@ -28,6 +28,12 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
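A minimal usage sketch (not part of the diff) of the portable shuffle macro added above; the helper name warp_sum and its width parameter are illustrative, not from the commit:

// Assumes VLLM_SHFL_DOWN_SYNC from the header above is in scope.
// Sum-reduce a value across `width` lanes of a warp/wavefront; the ROCm
// shrink kernel later in this commit uses exactly this offset-halving pattern.
template <int width>
__device__ inline float warp_sum(float v) {
#pragma unroll
  for (int offset = width / 2; offset > 0; offset /= 2) {
    v += VLLM_SHFL_DOWN_SYNC(v, offset);  // lane i accumulates lane i+offset
  }
  return v;  // lane 0 holds the full sum
}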
@@ -1,8 +1,14 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
@@ -11,6 +17,24 @@

namespace cg = cooperative_groups;

#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
  // Does not handle the case of long datatypes
  char *d = reinterpret_cast<char *>(dst);
  const char *s = reinterpret_cast<const char *>(src);
  size_t i = 0;
#pragma unroll
  for (i = 0; i < len; ++i) {
    d[i] = s[i];
  }
  return dst;
}
#endif

#ifndef USE_ROCM

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
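A brief usage sketch (not from the diff) of the ROCm-only memcpy_blocking helper added in the hunk above; the names load_x_tile, X and base are illustrative:

// ROCm path only: stage 8 fp16 values (16 bytes) from global memory with a
// simple unrolled byte copy, standing in for the CUDA-only cuda::pipeline
// async copies used on the NVIDIA code path.
__device__ inline void load_x_tile(__half (&x_tile)[8], const __half *X, size_t base) {
  memcpy_blocking<8 * sizeof(__half)>(x_tile, X + base);
}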
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  }
}

#else

template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  size_t j = blockIdx.x;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  __shared__ float y_warpwise[ty];

  float y = 0;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      x_vec.load(X + (batch_idx * feat_in) +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
      w_vec.load(W + (idx * feat_out + j) * feat_in +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
    }

    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
    }

    __syncthreads();

    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      y += sum;
    }
  }

  if (threadIdx.x == 0) {
    y_warpwise[threadIdx.y] = y;
  }
  __syncthreads();

  float y_write = 0.f;
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y_write += y_warpwise[i];
  }

  // write Y;
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t y_idx = batch_idx * full_y_size + y_offset + j;
    Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
  }
}

#endif

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
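A hedged sketch (names and the float element type are illustrative, not from the diff) of the indexing the kernel above performs: indicies selects which stacked LoRA weight matrix a request uses, and a negative entry makes the kernel return early, i.e. no adapter is applied for that sequence.

// W holds num_loras * num_layers matrices of shape (feat_out, feat_in) back
// to back; the kernel reads row j of the selected matrix at
// W + (idx * feat_out + j) * feat_in.
__host__ __device__ inline
const float *lora_weight_slice(const float *W, const int64_t *indicies,
                               int64_t batch_idx, int64_t num_layers,
                               int64_t layer_idx, int64_t feat_in,
                               int64_t feat_out) {
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) return nullptr;           // no adapter for this sequence
  return W + idx * feat_out * feat_in;   // start of the (feat_out, feat_in) matrix
}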
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
    sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
#ifndef USE_ROCM
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
    size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
                   threadIdx.z * ty + threadIdx.y;
    Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
  }
}
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                                      scale);
    }
  } else {
#ifndef USE_ROCM
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                                      full_y_size, num_layers, layer_idx,
                                      scale);
    }
#else
    constexpr size_t rocm_warp_size = warpSize;

#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
    feat_in % (rocm_warp_size * vec_size_) == 0

#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
    if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
      constexpr size_t vec_size_shrink = vec_size_; \
      constexpr int tx = tx_; \
      constexpr int ty = ty_; \
      dim3 nblks(feat_out, batch_size); \
      dim3 nthrs(tx, ty); \
      bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
                         vec_size_shrink * sizeof(in_T), \
                         vec_size_shrink * sizeof(W_T), \
                         tx, ty, tz> \
        <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
                                      full_y_size, num_layers, layer_idx, \
                                      scale); \
    }

    static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
                  CHECK_INPUT_TILEABLE_BY(16) ||
                  CHECK_INPUT_TILEABLE_BY( 8) ||
                  CHECK_INPUT_TILEABLE_BY( 4) ||
                  CHECK_INPUT_TILEABLE_BY( 2) ||
                  CHECK_INPUT_TILEABLE_BY( 1));

    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)

#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
  }
}
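A worked example (assumed sizes, not from the diff) of how the launch macro above picks a configuration, assuming a 64-lane wavefront with vec_size = 8 and feat_in = 4096: CHECK_INPUT_TILEABLE_BY(32) holds, so the factor-32 branch launches with tx = 64 and ty = 32 / 8 = 4.

// feat_in and vec_size are assumed values for illustration only.
constexpr size_t rocm_warp_size = 64;   // CDNA wavefront size
constexpr size_t vec_size = 8;
constexpr size_t feat_in = 4096;        // assumed LoRA input width
constexpr size_t tx = rocm_warp_size, ty = 32 / vec_size;
constexpr size_t tile_size = tx * ty * vec_size;                      // 2048
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;   // 2
static_assert(feat_in % (rocm_warp_size * 32) == 0, "factor-32 branch applies");
static_assert(num_tiles == 2, "each thread block sweeps feat_in in two tiles");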
@@ -1,8 +1,6 @@
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
@@ -10,6 +8,9 @@

#include <type_traits>

#include "../type_convert.h"
#include "../../cuda_compat.h"

#define FLASHINFER_INLINE \
  inline __attribute__((always_inline)) __device__ __host__
@@ -1,12 +1,11 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>

#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"

namespace {

//====== utils ======
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

}  // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
@@ -0,0 +1,11 @@
#pragma once

#include <torch/extension.h>

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale);

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset);
@@ -0,0 +1,13 @@
#include <torch/extension.h>

#include "punica_ops.h"

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
@@ -0,0 +1,82 @@
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__

#ifndef USE_ROCM

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#else

#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__

typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;

__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
  return __hip_bfloat162{val, val};
}

__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
  return __hip_bfloat162{vall, valr};
}

template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
  return static_cast<T_dst>(val);
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
  return __half2float(val);
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
  return __float2half(val);
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
  return __bfloat162float(val);
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
  return __float2bfloat16(val);
}

template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
  return a + b;
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
  return __hadd(a, b);
}

template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
  return __hadd(a, b);
}

#undef __TYPE_CONVERT__HOST_DEVICE__

#endif  // USE_ROCM

#endif  // CSRC__PUNICA__TYPE_CONVERT_H__
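A short usage sketch (not part of the diff; the helper name accumulate_out is illustrative) showing how the ROCm shims above are combined, mirroring the way the kernel paths in this commit write back into Y:

// ROCm path: nv_bfloat16 is typedef'd to __hip_bfloat16 above, so the same
// call sites compile against either back end.
__device__ inline void accumulate_out(nv_bfloat16 *Y, size_t y_idx, float partial) {
  Y[y_idx] = vllm_add<nv_bfloat16>(Y[y_idx],
                                   convert_type<float, nv_bfloat16>(partial));
}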
setup.py
@@ -385,12 +385,12 @@ ext_modules = []
if _is_cuda():
    ext_modules.append(CMakeExtension(name="vllm._moe_C"))

    if _install_punica():
        ext_modules.append(CMakeExtension(name="vllm._punica_C"))

if not _is_neuron():
    ext_modules.append(CMakeExtension(name="vllm._C"))

    if _install_punica():
        ext_modules.append(CMakeExtension(name="vllm._punica_C"))

package_data = {
    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}