[Kernel] Flashinfer correctness fix for v0.1.3 (#7319)

commit ec2affa8ae (parent 86ab567bae)
Author: Lily Liu
Date:   2024-08-12 00:59:17 -07:00
Committed by: GitHub
3 changed files with 20 additions and 24 deletions

.buildkite/test-pipeline.yaml

@@ -60,8 +60,6 @@ steps:
   - vllm/
   - tests/basic_correctness
   commands:
-  # This flashinfer installation will fail on AMD ROCm, so it is set as optional.
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
   - pytest -v -s basic_correctness/test_basic_correctness.py
   - pytest -v -s basic_correctness/test_cpu_offload.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
@@ -157,7 +155,6 @@ steps:
   - vllm/
   - tests/models
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - pytest -v -s models -m \"not vlm\"

 - label: Vision Language Models Test # 42min
@@ -212,7 +209,6 @@ steps:
   - vllm/attention
   - tests/kernels
   commands:
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4

@@ -331,7 +327,6 @@ steps:
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
   - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
   - pytest -v -s -x lora/test_mixtral.py

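These hunks drop the ad-hoc FlashInfer v0.1.2 wheel installs from the Buildkite steps, so the tests pick up whatever FlashInfer version the image ships (bumped to v0.1.3 in the Dockerfile change below). The `|| true` in the first hunk existed because the CUDA wheel cannot be installed on AMD ROCm; a test that needs the package can guard itself in the same spirit. A minimal pytest sketch, illustrative only and not part of this commit:

```python
import pytest

# Skip this module cleanly when the optional FlashInfer wheel is absent
# (e.g. on ROCm, where the CUDA wheel used in CI cannot be installed).
flashinfer = pytest.importorskip("flashinfer")


def test_flashinfer_is_importable():
    # Trivial illustrative check: the package imported under its expected name.
    assert flashinfer.__name__ == "flashinfer"
```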

Dockerfile

@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl

 #################### vLLM installation IMAGE ####################
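With the CI installs removed, the Dockerfile is the single place the FlashInfer wheel is pinned, now at v0.1.3. A quick sanity check inside the built image can confirm the pin took effect; a minimal sketch using only the standard library, with the expected version prefix taken from the wheel name above:

```python
from importlib.metadata import PackageNotFoundError, version

try:
    # Distribution name taken from the wheel filename pinned in the Dockerfile.
    installed = version("flashinfer")
except PackageNotFoundError:
    raise SystemExit("flashinfer is not installed in this image")

# The pinned wheel is 0.1.3+cu121torch2.4; fail loudly on a mismatch.
assert installed.startswith("0.1.3"), f"unexpected flashinfer version: {installed}"
print(f"flashinfer {installed} OK")
```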

vllm/attention/backends/flashinfer.py

@@ -117,6 +117,7 @@ class FlashInferMetadata(AttentionMetadata):
     # The data type of the paged kv cache
     data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
+    is_profile_run: bool = False

     def __post_init__(self):
         # Refer to
@@ -127,7 +128,6 @@ class FlashInferMetadata(AttentionMetadata):
             raise ValueError(
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
-        self.is_profile_run = is_block_tables_empty(self.block_tables)

     def begin_forward(self):
         if self.num_prefill_tokens > 0:
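The removed line above derived the flag in `__post_init__` from `is_block_tables_empty(self.block_tables)`; after this change the builder sets `is_profile_run` explicitly (see the later hunks). For reference, a block-table emptiness check of that kind is roughly the following, a paraphrase for illustration rather than the exact vLLM helper:

```python
from typing import Any, Dict, Optional


def is_block_tables_empty(block_tables: Optional[Dict[int, Any]]) -> bool:
    # A profile run has allocated no KV-cache blocks, so the mapping is
    # either absent or maps every sequence id to None.
    if block_tables is None:
        return True
    return all(value is None for value in block_tables.values())
```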
@@ -141,23 +141,20 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_last_page_len is not None
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
-            # The profile run does not read kv cache.
-            # Both paged_kv_indices and paged_kv_last_page_len are empty.
-            # paged_kv_indptr is a zero tensor with size batch_size + 1.
-            if self.is_profile_run:
-                self.paged_kv_indptr = torch.zeros(batch_size + 1,
-                                                   device=self.device)
-            else:
+            # We will use flash attention for profiling to
+            # determine the number of blocks. Therefore,
+            # we don't need to prepare the input for flashinfer for profile run.
+            if not self.is_profile_run:
                 self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
-            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
-                self.device)
-            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-            self.prefill_wrapper.end_forward()
-            self.prefill_wrapper.begin_forward(
-                self.query_start_loc, self.paged_kv_indptr,
-                self.paged_kv_indices, self.paged_kv_last_page_len,
-                self.num_qo_heads, self.num_kv_heads, self.head_dim,
-                self.page_size)
+                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                    self.device)
+                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+                self.prefill_wrapper.end_forward()
+                self.prefill_wrapper.begin_forward(
+                    self.query_start_loc, self.paged_kv_indptr,
+                    self.paged_kv_indices, self.paged_kv_last_page_len,
+                    self.num_qo_heads, self.num_kv_heads, self.head_dim,
+                    self.page_size)
         else:
             if not self.use_cuda_graph:
                 assert self.paged_kv_indices is not None
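This hunk is the core of the fix: the profile run sizes the KV cache with flash attention (per the new comment), so the metadata no longer fabricates a zero `paged_kv_indptr` and unconditionally calls the prefill wrapper; it skips FlashInfer input preparation entirely when `is_profile_run` is set. A stripped-down sketch of that control flow, with toy classes standing in for vLLM's and FlashInfer's types:

```python
from dataclasses import dataclass, field
from typing import List, Optional


class _FakePrefillWrapper:
    """Stand-in for a paged-KV prefill wrapper; only mimics the call pattern."""

    def end_forward(self) -> None:
        pass

    def begin_forward(self, *args) -> None:
        print("prefill wrapper prepared with", len(args), "arguments")


@dataclass
class ToyPrefillMetadata:
    is_profile_run: bool
    paged_kv_indptr: Optional[List[int]] = None
    paged_kv_indices: List[int] = field(default_factory=list)
    paged_kv_last_page_len: List[int] = field(default_factory=list)
    prefill_wrapper: _FakePrefillWrapper = field(default_factory=_FakePrefillWrapper)

    def begin_forward(self) -> None:
        # Mirrors the gating above: FlashInfer inputs are only built for real
        # runs; the profile run returns without touching the wrapper.
        if self.is_profile_run:
            return
        self.prefill_wrapper.end_forward()
        self.prefill_wrapper.begin_forward(self.paged_kv_indptr,
                                           self.paged_kv_indices,
                                           self.paged_kv_last_page_len)


ToyPrefillMetadata(is_profile_run=True).begin_forward()   # no wrapper calls
ToyPrefillMetadata(is_profile_run=False,
                   paged_kv_indptr=[0]).begin_forward()   # wrapper prepared
```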
@@ -249,6 +246,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []
+        self.is_profile_run: bool = False

     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
@@ -305,6 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # and paged_kv_last_page_len for profile run because we will
             # create dummy inputs.
             if is_profile_run:
+                self.is_profile_run = is_profile_run
                 return

             block_table = block_tables[seq_id]
@@ -435,7 +435,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
-            use_cuda_graph=use_captured_graph)
+            use_cuda_graph=use_captured_graph,
+            is_profile_run=self.is_profile_run)


 class FlashInferImpl(AttentionImpl):
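Taken together, the builder records `is_profile_run` while adding sequence groups and passes it into `FlashInferMetadata`, so the decision made at build time reaches `begin_forward`. Selecting this backend at runtime uses the same environment variable the pipeline above sets for XFORMERS; a hedged usage sketch, where the model name is an arbitrary small example and not part of this commit:

```python
import os

# Select the FlashInfer attention backend, analogous to the pipeline's
# VLLM_ATTENTION_BACKEND=XFORMERS usage; set it before importing vllm.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # assumed example model
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```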