mirror of https://github.com/vllm-project/vllm
Allocate more shared memory to attention kernel (#1154)
This commit is contained in:
parent
03ffd0a022
commit
cf5cb1e33e
|
@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||
cudaFuncSetAttribute( \
|
||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
|
@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
|
|||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_context_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||
// Keep that in sync with the logic here!
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
|
||||
dim3 grid(num_heads, num_seqs);
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
{
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
cudaGetDevice(&device);
|
||||
}
|
||||
else {
|
||||
device = device_id;
|
||||
}
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
}
|
11
setup.py
11
setup.py
|
@ -195,6 +195,17 @@ quantization_extension = CUDAExtension(
|
|||
)
|
||||
ext_modules.append(quantization_extension)
|
||||
|
||||
# Misc. CUDA utils.
|
||||
cuda_utils_extension = CUDAExtension(
|
||||
name="vllm.cuda_utils",
|
||||
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cuda_utils_extension)
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
return os.path.join(ROOT_DIR, *filepath)
|
||||
|
|
|
@ -7,8 +7,12 @@ from xformers import ops as xops
|
|||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
MAX_SEQ_LEN = 8192
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
|
@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
|
|||
device="cuda")
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
|
||||
|
@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
|
|||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import enum
|
||||
from platform import uname
|
||||
import uuid
|
||||
from platform import uname
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from vllm import cuda_utils
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
GPU = enum.auto()
|
||||
|
@ -25,6 +27,15 @@ class Counter:
|
|||
self.counter = 0
|
||||
|
||||
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name
|
||||
max_shared_mem = cuda_utils.get_device_attribute(
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
|
||||
return int(max_shared_mem)
|
||||
|
||||
|
||||
def get_gpu_memory(gpu: int = 0) -> int:
|
||||
"""Returns the total memory of the GPU in bytes."""
|
||||
return torch.cuda.get_device_properties(gpu).total_memory
|
||||
|
|
|
@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.utils import get_gpu_memory
|
||||
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
|
||||
|
||||
|
||||
class Worker:
|
||||
|
@ -136,6 +136,10 @@ class Worker:
|
|||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
||||
self.cache_config = cache_config
|
||||
self.block_size = cache_config.block_size
|
||||
|
||||
_check_if_can_support_max_seq_len(self.scheduler_config.max_model_len,
|
||||
self.block_size)
|
||||
|
||||
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config)
|
||||
self.cache_events = self.cache_engine.events
|
||||
|
@ -347,3 +351,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
|
|||
|
||||
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
|
||||
return x + [0] * (max_len - len(x))
|
||||
|
||||
|
||||
def _check_if_can_support_max_seq_len(max_seq_len: int,
|
||||
block_size: int) -> None:
|
||||
# Follows the logic in
|
||||
# attention_kernels.cu::single_query_cached_kv_attention_launcher
|
||||
max_shared_mem = get_max_shared_memory_bytes()
|
||||
float32_bytes = torch.finfo(torch.float).bits // 8
|
||||
padded_max_seq_len = (
|
||||
(max_seq_len + block_size - 1) / block_size) * block_size
|
||||
# padded_max_seq_len + extra buffer
|
||||
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
|
||||
if padded_max_seq_len * float32_bytes > max_shared_mem:
|
||||
raise RuntimeError(
|
||||
f"vLLM cannot currently support max_model_len={max_seq_len} "
|
||||
f"with block_size={block_size} on GPU with compute "
|
||||
f"capability {torch.cuda.get_device_capability()} "
|
||||
f"(required shared memory {required_shared_mem} > "
|
||||
f"available shared memory {max_shared_mem}). "
|
||||
"This will be fixed in a future release.")
|
||||
|
|
Loading…
Reference in New Issue