Avoid compiling kernels for double data type (#933)

commit 8ce9c50d40
parent 32b6816e55
Author: Woosuk Kwon
Date: 2023-09-02 14:59:47 +09:00 (committed by GitHub)
5 changed files with 28 additions and 21 deletions
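
For context: AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ...) dispatches over the two base floating types, Float and Double, plus the two listed extras, so every call site also compiled a double-precision kernel that vLLM never runs. The new VLLM_DISPATCH_FLOATING_TYPES macro enumerates its cases explicitly and leaves Double out, saving one template instantiation per call site. Roughly, a dispatched call site now expands to a switch like the sketch below, assuming the AT_DISPATCH_CASE / AT_DISPATCH_SWITCH machinery of PyTorch 2.0.1 (simplified: the real macros also record dispatch info and produce richer error text):

// Simplified, hand-written equivalent of
// VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kernel", [&] { ... });
switch (input.scalar_type()) {
  case at::ScalarType::Float: {
    using scalar_t = float;          // lambda body instantiated for fp32
    /* launch kernel<scalar_t> */
    break;
  }
  case at::ScalarType::Half: {
    using scalar_t = at::Half;       // ... and for fp16
    /* launch kernel<scalar_t> */
    break;
  }
  case at::ScalarType::BFloat16: {
    using scalar_t = at::BFloat16;   // ... and for bf16
    /* launch kernel<scalar_t> */
    break;
  }
  default:
    // No Double case anymore: an fp64 tensor now raises at runtime
    // instead of forcing a slow-to-compile, never-used instantiation.
    TORCH_CHECK(false, "kernel not implemented for '",
                toString(input.scalar_type()), "'");
}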

csrc/activation_kernels.cu

@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+
+#include "dispatch_utils.h"

 namespace vllm {

 template<typename T>
@@ -34,9 +36,7 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
@@ -71,9 +71,7 @@ __global__ void activation_kernel(
   dim3 grid(num_tokens); \
   dim3 block(std::min(d, 1024)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  AT_DISPATCH_FLOATING_TYPES_AND2( \
-    at::ScalarType::Half, \
-    at::ScalarType::BFloat16, \
+  VLLM_DISPATCH_FLOATING_TYPES( \
     input.scalar_type(), \
     "activation_kernel", \
     [&] { \

csrc/cache_kernels.cu

@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+
+#include "dispatch_utils.h"

 #include <algorithm>
 #include <cassert>
 #include <map>
@@ -125,9 +127,7 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -202,9 +202,7 @@ void reshape_and_cache(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
@@ -364,9 +362,7 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {

csrc/dispatch_utils.h (new file, 14 lines)

@@ -0,0 +1,14 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)  \
+  AT_DISPATCH_SWITCH(                                  \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
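
A minimal usage sketch for the new header. The op scale and kernel scale_kernel below are hypothetical, invented for illustration; the dispatch pattern mirrors the call sites changed in this commit:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

namespace vllm {

// Toy elementwise kernel; with VLLM_DISPATCH_FLOATING_TYPES it is
// instantiated only for float, half, and bfloat16.
template<typename scalar_t>
__global__ void scale_kernel(
  scalar_t* __restrict__ out,
  const scalar_t* __restrict__ in,
  const float factor,
  const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = (scalar_t) ((float) in[i] * factor);
  }
}

} // namespace vllm

void scale(torch::Tensor& out, torch::Tensor& in, float factor) {
  const int n = in.numel();
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    in.scalar_type(),
    "scale_kernel",
    [&] {
      // AT_DISPATCH_CASE defines scalar_t inside each case, so one
      // lambda body serves all three supported dtypes.
      vllm::scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        in.data_ptr<scalar_t>(),
        factor,
        n);
    });
}

Calling such an op with a double tensor now throws a c10::Error from the switch's default case instead of silently compiling and running an fp64 kernel.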

csrc/layernorm_kernels.cu

@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>

+#include "dispatch_utils.h"
 #include "reduction_utils.cuh"

 namespace vllm {
@@ -46,9 +47,7 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {

csrc/pos_encoding_kernels.cu

@@ -1,6 +1,8 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
+
+#include "dispatch_utils.h"

 namespace vllm {

 template<typename scalar_t>
@@ -83,9 +85,7 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {