From 96853af5a830d42496aa6cfd5c670d073d6a0209 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Fri, 14 Jul 2023 20:06:40 -0400
Subject: [PATCH] Optimize MQA Kernel (#452)

---
 csrc/attention.cpp                        |  1 +
 csrc/attention/attention_kernels.cu       | 31 +++++++---
 vllm/config.py                            |  7 +++
 vllm/model_executor/layers/attention.py   | 43 +++++++----
 vllm/model_executor/models/gpt_bigcode.py | 74 +++++++----------------
 5 files changed, 84 insertions(+), 72 deletions(-)

diff --git a/csrc/attention.cpp b/csrc/attention.cpp
index b0ee4c906b..6be8a6d25a 100644
--- a/csrc/attention.cpp
+++ b/csrc/attention.cpp
@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 69ae88908c..9302204107 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -74,14 +74,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 #pragma unroll
     for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-      const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                      + head_idx * HEAD_SIZE * BLOCK_SIZE
+      const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                      + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;
@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
-    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
+    head_mapping_ptr, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
-    query_stride);
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride);
 
 // TODO(woosuk): Tune NUM_THREADS.
 template<
@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
 
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
     query, \
     key_cache, \
     value_cache, \
+    head_mapping, \
     scale, \
     block_tables, \
     context_lens, \
@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,          // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,      // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,    // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,   // [num_heads]
   float scale,
   torch::Tensor& block_tables,   // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,   // [num_seqs]
diff --git a/vllm/config.py b/vllm/config.py
index c34bf29536..9fed21f5c6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index f4550e82eb..d94649cf1d 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -76,10 +87,18 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
@@ -107,9 +126,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -118,6 +137,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
@@ -143,11 +163,12 @@ class PagedAttention(nn.Module):
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
 
@@ -157,8 +178,8 @@ class PagedAttention(nn.Module):
 
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 1f9befed2a..1c118ba8c0 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
-import numpy as np
 from transformers import GPTBigCodeConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
@@ -55,10 +54,12 @@ class GPTBigCodeAttention(nn.Module):
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
+        self.num_kv_heads = 1 if config.multi_query else self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
 
         self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
+                                           self.hidden_size + 2 * self.kv_dim,
                                            bias=True,
                                            gather_output=False,
                                            perform_initialization=False)
@@ -69,7 +70,8 @@ class GPTBigCodeAttention(nn.Module):
                                         perform_initialization=False)
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
-                                   scale=self.scale)
+                                   scale=self.scale,
+                                   num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -79,7 +81,8 @@ class GPTBigCodeAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+                            dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
                                 input_metadata, cache_event)
@@ -263,36 +266,6 @@ class GPTBigCodeForCausalLM(nn.Module):
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
-            def _expand_mqa_mha(qkv_array, n_head, head_dim):
-                """manipulates along axis=0 from MQA to MHA
-                inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
-                    with n_heads for q, then 1 for k, 1 for 1 v, times head dim
-                return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
-
-                TODO: this function is no longer needed once vllm supports MQA.
-                """
-                qkv_array = qkv_array.numpy()
-
-                dims_q = n_head * head_dim
-                # pylint: disable=unbalanced-tuple-unpacking
-                q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
-                                   axis=0)
-                # q is fine, but k & v have not replicated shape along the first
-                # axis as long as MQA is not nativly supported, increase memory
-                # and replicated (head_dim, hidden_dim) to
-                # (n_heads * head_dim, hidden_dim)
-                if k.ndim == 2 and v.ndim == 2:
-                    replication = (n_head, 1)  # weights
-                else:
-                    replication = n_head  # biases
-                # replicate n_head times for q, v
-                k, v = np.tile(k, replication), np.tile(v, replication)
-                # concat q, k, v along the first axis
-                # (n_heads * head_dim, hidden_dim)
-                # to (3 * n_heads * head_dim, hidden_dim)
-                qkv_array = np.concatenate((q, k, v), axis=0)
-                return torch.from_numpy(qkv_array)
-
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
                 # GPT-2's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
                 # When tensor parallelism is used, we shard the weights along
                 # the head dimension.
                 total_num_heads = self.config.num_attention_heads
+                total_num_kv_heads = (1 if self.config.multi_query else
+                                      total_num_heads)
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
+                total_kv_size = head_size * total_num_kv_heads
                 num_heads = total_num_heads // tensor_model_parallel_world_size
                 head_start = tensor_model_parallel_rank * num_heads
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
-                if name.endswith(".weight"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size, hidden_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
-                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif name.endswith(".bias"):
-                    loaded_weight = _expand_mqa_mha(loaded_weight,
-                                                    n_head=total_num_heads,
-                                                    head_dim=head_size)
-                    loaded_weight = loaded_weight.view(3, total_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight[:, head_start:head_end, :]
-                    loaded_weight = loaded_weight.reshape(-1)
-                else:
-                    raise ValueError(f"Unexpected parameter name {name}")
+                wq, wk, wv = torch.split(
+                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
+                    dim=0)
+
+                wq = wq[head_size * head_start:head_size * head_end]
+                if not self.config.multi_query:
+                    # Split the heads when using normal multi-head attention
+                    wk = wk[head_size * head_start:head_size * head_end]
+                    wv = wv[head_size * head_start:head_size * head_end]
+                # Else, keep the weights as is for multi-query attention
+
+                loaded_weight = torch.cat([wq, wk, wv], dim=0)
 
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,