[ROCm] Add support for Punica kernels on AMD GPUs (#3140)

Co-authored-by: miloice <jeffaw99@hotmail.com>
kliuae 2024-05-10 00:19:50 +08:00 committed by GitHub
parent 0ee535b294
commit ff5abcd746
10 changed files with 287 additions and 26 deletions

View File

@@ -219,7 +219,8 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cc")
"csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp")
#
# Copy GPU compilation flags+update for punica
@@ -243,6 +244,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
@@ -277,11 +281,6 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
@@ -292,3 +291,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
add_dependencies(default _punica_C)
endif()
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
endif()

View File

@@ -94,6 +94,9 @@ COPY . .
RUN python3 -m pip install --upgrade pip numba
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \

View File

@@ -28,6 +28,12 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
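
VLLM_SHFL_DOWN_SYNC papers over the difference between CUDA's __shfl_down_sync, which takes an explicit lane mask, and HIP's __shfl_down, which does not. A minimal sketch of the warp-sum pattern the ROCm shrink kernel builds on top of this macro; the helper name and the width parameter are illustrative, not part of this commit:

#include "cuda_compat.h"  // path within csrc/ assumed

// Sum one float per lane across a warp/wavefront of `width` lanes
// (e.g. 32 on NVIDIA, 64 on AMD). After the loop, lane 0 holds the total.
__device__ inline float warp_reduce_sum(float val, int width) {
  for (int offset = width / 2; offset > 0; offset /= 2) {
    val += VLLM_SHFL_DOWN_SYNC(val, offset);
  }
  return val;
}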

View File

@@ -1,8 +1,14 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
@@ -11,6 +17,24 @@
namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
  // Does not handle the case of long datatypes
  char *d = reinterpret_cast<char *>(dst);
  const char *s = reinterpret_cast<const char *>(src);
  size_t i = 0;
#pragma unroll
  for (i = 0; i < len; ++i) {
    d[i] = s[i];
  }
  return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
}
}
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  size_t j = blockIdx.x;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  __shared__ float y_warpwise[ty];

  float y = 0;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      x_vec.load(X + (batch_idx * feat_in) +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
      w_vec.load(W + (idx * feat_out + j) * feat_in +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
    }

    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
    }

    __syncthreads();

    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      y += sum;
    }
  }

  if (threadIdx.x == 0) {
    y_warpwise[threadIdx.y] = y;
  }
  __syncthreads();

  float y_write = 0.f;
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y_write += y_warpwise[i];
  }

  // write Y;
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t y_idx = batch_idx * full_y_size + y_offset + j;
    Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
  }
}
#endif
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
    sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
#ifndef USE_ROCM
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
    size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
                   threadIdx.z * ty + threadIdx.y;
    Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
  }
}
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale);
}
} else {
#ifndef USE_ROCM
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size, num_layers, layer_idx,
scale);
}
#else
    constexpr size_t rocm_warp_size = warpSize;

#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
    feat_in % (rocm_warp_size * vec_size_) == 0

#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
    if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
      constexpr size_t vec_size_shrink = vec_size_; \
      constexpr int tx = tx_; \
      constexpr int ty = ty_; \
      dim3 nblks(feat_out, batch_size); \
      dim3 nthrs(tx, ty); \
      bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
                         vec_size_shrink * sizeof(in_T), \
                         vec_size_shrink * sizeof(W_T), \
                         tx, ty, tz> \
        <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
                                      full_y_size, num_layers, layer_idx, \
                                      scale); \
    }

    static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
                  CHECK_INPUT_TILEABLE_BY(16) ||
                  CHECK_INPUT_TILEABLE_BY( 8) ||
                  CHECK_INPUT_TILEABLE_BY( 4) ||
                  CHECK_INPUT_TILEABLE_BY( 2) ||
                  CHECK_INPUT_TILEABLE_BY( 1));

    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)

#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
}
}
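
On the HIP path, bgmv_kernel dispatches through the LAUNCH_BGMV_SHRINK_KERNELS_ROCM chain above: each expansion is an if constexpr block, so the trailing else keeps only the largest tile factor that divides feat_in, and each loop iteration of the shrink kernel consumes tx * ty * vec_size input elements. A small compile-time sketch of that arithmetic, with helper names and example sizes chosen here for illustration:

#include <cstddef>

// Elements of feat_in consumed per loop iteration of the shrink kernel.
constexpr size_t tile_elems(size_t tx, size_t ty, size_t vec_size) {
  return tx * ty * vec_size;
}

// Number of iterations; partial tiles are guarded by the bounds check in the kernel.
constexpr size_t num_tiles(size_t feat_in, size_t tx, size_t ty, size_t vec_size) {
  return (feat_in + tile_elems(tx, ty, vec_size) - 1) / tile_elems(tx, ty, vec_size);
}

// Example: feat_in = 4096, wavefront of 64, vec_size = 8 -> factor 32 matches
// (4096 % (64 * 32) == 0), ty = 32 / 8 = 4, so each block sweeps 2 tiles of 2048.
static_assert(num_tiles(4096, 64, 4, 8) == 2, "4096 features -> 2 tiles of 2048");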

View File

@@ -1,8 +1,6 @@
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
@@ -10,6 +8,9 @@
#include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__

View File

@@ -1,12 +1,11 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}

csrc/punica/punica_ops.h Normal file
View File

@@ -0,0 +1,11 @@
#pragma once
#include <torch/extension.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale);

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset);
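
These declarations let punica_ops.cu and the new pybind translation unit share the same entry points. A rough host-side sketch of calling them from C++; the tensor shapes and dtypes below are assumptions chosen for illustration, and the real constraints are enforced by the checks in csrc/punica/punica_ops.cu, not by this header:

#include <torch/torch.h>
#include "punica_ops.h"

// Hypothetical shapes: x [batch, feat_in], y [batch, feat_out],
// w packed per-LoRA/per-layer weights, indicies [batch] of LoRA ids.
void run_bgmv_once() {
  auto half_cuda = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto x = torch::randn({8, 4096}, half_cuda);
  auto w = torch::randn({4, 1, 16, 4096}, half_cuda);
  auto y = torch::zeros({8, 16}, half_cuda);
  auto indicies = torch::zeros({8}, torch::dtype(torch::kInt64).device(torch::kCUDA));
  dispatch_bgmv(y, x, w, indicies, /*layer_idx=*/0, /*scale=*/1.0f);
}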

View File

@@ -0,0 +1,13 @@
#include <torch/extension.h>
#include "punica_ops.h"
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}

View File

@@ -0,0 +1,82 @@
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
  return __hip_bfloat162{val, val};
}
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
  return __hip_bfloat162{vall, valr};
}
template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
  return static_cast<T_dst>(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
  return __half2float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
  return __float2half(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
  return __bfloat162float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
  return __float2bfloat16(val);
}
template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
  return a + b;
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
  return __hadd(a, b);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
  return __hadd(a, b);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
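
On CUDA the kernels keep using the native __half/__nv_bfloat16 operators; the shim above supplies convert_type and vllm_add so the ROCm code paths can accumulate in float and add in the output's native type. A minimal device-side sketch of that pattern, with the helper name chosen here for illustration (it only compiles when USE_ROCM is defined, since that is where these helpers live):

#ifdef USE_ROCM
#include "type_convert.h"

// Multiply-accumulate one element: widen to float, scale, then add back
// into *y using the output type's own addition (e.g. __hadd for __half).
template <typename T>
__device__ inline void scaled_accumulate(T *y, T w, T x, float scale) {
  float partial = convert_type<T, float>(w) * convert_type<T, float>(x) * scale;
  *y = vllm_add<T>(*y, convert_type<float, T>(partial));
}
#endif  // USE_ROCM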

View File

@@ -385,12 +385,12 @@ ext_modules = []
if _is_cuda():
    ext_modules.append(CMakeExtension(name="vllm._moe_C"))
    if _install_punica():
        ext_modules.append(CMakeExtension(name="vllm._punica_C"))

if not _is_neuron():
    ext_modules.append(CMakeExtension(name="vllm._C"))
    if _install_punica():
        ext_modules.append(CMakeExtension(name="vllm._punica_C"))

package_data = {
    "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
}