mirror of https://github.com/vllm-project/vllm
[Kernel] Flashinfer correctness fix for v0.1.3 (#7319)
This commit is contained in:
parent
86ab567bae
commit
ec2affa8ae
|
@ -60,8 +60,6 @@ steps:
|
|||
- vllm/
|
||||
- tests/basic_correctness
|
||||
commands:
|
||||
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
|
@ -157,7 +155,6 @@ steps:
|
|||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s models -m \"not vlm\"
|
||||
|
||||
- label: Vision Language Models Test # 42min
|
||||
|
@ -212,7 +209,6 @@ steps:
|
|||
- vllm/attention
|
||||
- tests/kernels
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
|
@ -331,7 +327,6 @@ steps:
|
|||
# NOTE: don't test llama model here, it seems hf implementation is buggy
|
||||
# see https://github.com/vllm-project/vllm/pull/5689 for details
|
||||
- pytest -v -s distributed/test_custom_all_reduce.py
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s -x lora/test_mixtral.py
|
||||
|
||||
|
|
|
@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
|
|||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
|
|
@ -117,6 +117,7 @@ class FlashInferMetadata(AttentionMetadata):
|
|||
# The data type of the paged kv cache
|
||||
data_type: torch.dtype = None
|
||||
device: torch.device = torch.device("cuda")
|
||||
is_profile_run: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Refer to
|
||||
|
@ -127,7 +128,6 @@ class FlashInferMetadata(AttentionMetadata):
|
|||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
self.is_profile_run = is_block_tables_empty(self.block_tables)
|
||||
|
||||
def begin_forward(self):
|
||||
if self.num_prefill_tokens > 0:
|
||||
|
@ -141,23 +141,20 @@ class FlashInferMetadata(AttentionMetadata):
|
|||
assert self.paged_kv_last_page_len is not None
|
||||
batch_size = self.query_start_loc.shape[0] - 1
|
||||
assert batch_size >= 0
|
||||
# The profile run does not read kv cache.
|
||||
# Both paged_kv_indices and paged_kv_last_page_len are empty.
|
||||
# paged_kv_indptr is a zero tensor with size batch_size + 1.
|
||||
if self.is_profile_run:
|
||||
self.paged_kv_indptr = torch.zeros(batch_size + 1,
|
||||
device=self.device)
|
||||
else:
|
||||
# We will use flash attention for profiling to
|
||||
# determine the number of blocks. Therefore,
|
||||
# we don't need to prepare the input for flashinfer for profile run.
|
||||
if not self.is_profile_run:
|
||||
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.prefill_wrapper.end_forward()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.query_start_loc, self.paged_kv_indptr,
|
||||
self.paged_kv_indices, self.paged_kv_last_page_len,
|
||||
self.num_qo_heads, self.num_kv_heads, self.head_dim,
|
||||
self.page_size)
|
||||
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
|
||||
self.device)
|
||||
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
|
||||
self.prefill_wrapper.end_forward()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.query_start_loc, self.paged_kv_indptr,
|
||||
self.paged_kv_indices, self.paged_kv_last_page_len,
|
||||
self.num_qo_heads, self.num_kv_heads, self.head_dim,
|
||||
self.page_size)
|
||||
else:
|
||||
if not self.use_cuda_graph:
|
||||
assert self.paged_kv_indices is not None
|
||||
|
@ -249,6 +246,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||
# paged_kv_last_page_len is the length of the last page of each request
|
||||
self.paged_kv_last_page_len: List[int] = []
|
||||
|
||||
self.is_profile_run: bool = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
|
@ -305,6 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||
# and paged_kv_last_page_len for profile run because we will
|
||||
# create dummy inputs.
|
||||
if is_profile_run:
|
||||
self.is_profile_run = is_profile_run
|
||||
return
|
||||
|
||||
block_table = block_tables[seq_id]
|
||||
|
@ -435,7 +435,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||
query_start_loc=query_start_loc,
|
||||
device=device,
|
||||
data_type=kv_cache_dtype,
|
||||
use_cuda_graph=use_captured_graph)
|
||||
use_cuda_graph=use_captured_graph,
|
||||
is_profile_run=self.is_profile_run)
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
|
|
Loading…
Reference in New Issue