[Speculative decoding] Add ngram prompt lookup decoding (#4237)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
leiwen83 2024-05-02 02:13:03 +08:00 committed by GitHub
parent 8b798eec75
commit b38e42fbca
14 changed files with 1003 additions and 318 deletions

View File

@ -1,4 +1,5 @@
import asyncio
from itertools import cycle
from typing import List, Optional, Tuple, Union
import pytest
@ -185,3 +186,60 @@ def get_output_from_llm_generator(
del llm
return tokens, token_ids
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generate max_output_len tokens, then set the
# sampling params to ignore the EOS token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

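The prompt-cycling line in the helper above can be easy to misread; the small sketch below shows what it produces when batch_size exceeds the number of prompts (illustrative values only, not part of this PR):

from itertools import cycle, islice

prompts = ["a", "b", "c"]
batch_size = 5

# Equivalent to: [p for p, _ in zip(cycle(prompts), range(batch_size))]
batch = list(islice(cycle(prompts), batch_size))
assert batch == ["a", "b", "c", "a", "b"]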
View File

@ -35,7 +35,8 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test)
@pytest.mark.parametrize(
@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generate max_output_len tokens, then set the
# sampling params to ignore the EOS token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

View File

@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding
and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, the test cases can be much
simpler than the multi-step draft-model ones.
However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say that, at the very least, ngram speculation does not
break the correctness of the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

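For readers unfamiliar with prompt lookup decoding, here is a minimal pure-Python sketch of the proposal step the tests above exercise, using a hypothetical helper name propose_from_prompt (not code from this PR); the real NGramWorker operates on torch tensors and pads short proposals with zeros:

from typing import List, Optional


def propose_from_prompt(token_ids: List[int], ngram_max: int, ngram_min: int,
                        num_speculative_tokens: int) -> Optional[List[int]]:
    """Toy drafter: propose the tokens that followed an earlier occurrence
    of the sequence's trailing n-gram, trying the largest window first."""
    for n in range(min(ngram_max, len(token_ids) - 1), ngram_min, -1):
        suffix = token_ids[-n:]
        # Scan all windows of length n that end before the trailing one.
        for start in range(len(token_ids) - n):
            if token_ids[start:start + n] == suffix:
                return token_ids[start + n:start + n + num_speculative_tokens]
    return None  # no candidate; fall back to normal (non-speculative) decoding


# Mirrors one of the unit-test prompts further below: the trailing 11 also
# appears at position 0, so the proposal is the five tokens that followed it.
assert propose_from_prompt([11, 12, 13, 14, 15, 16, 11], 3, 0, 5) == [12, 13, 14, 15, 16]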
View File

@ -6,8 +6,8 @@ import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
MultiStepWorker)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from .utils import (assert_logprobs_dict_allclose, create_batch,
@ -117,8 +117,8 @@ def test_same_output_for_single_step():
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output = multi_step_worker.execute_model_multi_step(
**multi_step_execute_model_data.to_dict(), num_steps=num_steps)
actual_output, _ = multi_step_worker.sampler_output(
**multi_step_execute_model_data.to_dict(), sample_len=num_steps)
assert len(actual_output) == num_steps
actual_output = actual_output[0]
@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output = multi_step_worker.execute_model_multi_step(
**execute_model_data.to_dict(), num_steps=num_steps)
multi_step_output, _ = multi_step_worker.sampler_output(
**execute_model_data.to_dict(), sample_len=num_steps)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify DraftModelTop1Proposer correctly handles case where all sequences
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
device = 'cuda:0'
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=2048,
vocab_size=vocab_size,
max_proposal_len=2048,
)
draft_worker.execute_model_multi_step.return_value = [
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
device=device,
dtype=torch.long),
) for _ in range(k)
]
], True
execute_model_data, _, _ = create_batch(batch_size, k)
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify DraftModelTop1Proposer correctly handles case where no sequences
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
prompt_len = 10
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=prompt_len + k - 1,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
execute_model_data, _, _ = create_batch(batch_size,
@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify DraftModelTop1Proposer correctly handles case some sequences can
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=long_prompt_len + prev_output_token_len + k - 1,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.execute_model_multi_step.return_value = [
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
device=device,
dtype=torch.long),
) for _ in range(k)
]
], True
execute_model_data, _, _ = create_batch(
batch_size,
@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)

View File

@ -0,0 +1,206 @@
import torch
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import (create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker)
def test_ngram_algo_correctness_for_single_no_match():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window to (0, 3], i.e. allowed window sizes 1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([1])
assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window to (0, 3], i.e. allowed window sizes 1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
# shall find no candidate as exceed max_proposal_len
[
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
38, 31, 32, 33
],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([5])
assert proposals.proposal_lens.tolist(
) == [proposal_len for _ in range(4)] + [0]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == 0
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
assert proposals.proposal_token_ids[4][i] == -1
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window to (0, 3], i.e. allowed window sizes 1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([3])
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]

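The last prompt in test_ngram_algo_correctness_for_batches_not_match_all is expected to get an empty proposal purely because of its length; the arithmetic below restates the check performed by Top1Proposer._split_by_max_model_len with the numbers from that test (an illustration, not test code):

max_proposal_len = 20   # Top1Proposer limit configured in the test
proposal_len = 5        # tokens requested per speculation step
seq_len = 21            # length of the fifth prompt in the batch above

# Speculation only happens while seq_len + proposal_len stays strictly
# below max_proposal_len; 21 + 5 = 26 >= 20, so this sequence is skipped,
# its proposal_lens entry is 0 and its token ids are filled with -1.
assert not (seq_len + proposal_len < max_proposal_len)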
View File

@ -682,6 +682,8 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
@ -708,6 +710,10 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@ -742,39 +748,57 @@ class SpeculativeConfig:
draft_code_revision = None
draft_quantization = None
draft_model_config = ModelConfig(
model=speculative_model,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
max_context_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)
if speculative_model == "[ngram]":
assert (ngram_prompt_lookup_max is not None
and ngram_prompt_lookup_max > 0)
if ngram_prompt_lookup_min is None:
ngram_prompt_lookup_min = 0
else:
assert ngram_prompt_lookup_max > ngram_prompt_lookup_min
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future we may refactor this out and set the
# draft-related config to None here.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
max_context_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
)
@staticmethod
@ -842,6 +866,8 @@ class SpeculativeConfig:
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
ngram_prompt_lookup_max: int,
ngram_prompt_lookup_min: int,
):
"""Create a SpeculativeConfig object.
@ -854,6 +880,8 @@ class SpeculativeConfig:
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
self._verify_args()
@ -877,7 +905,10 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def __repr__(self) -> str:
draft_model = self.draft_model_config.model
if self.ngram_prompt_lookup_max > 0:
draft_model = "[ngram]"
else:
draft_model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"

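The [ngram] branch above reduces to a small validation rule for the window bounds; restated in isolation below with a hypothetical helper name (a sketch, not the config code itself):

from typing import Optional, Tuple


def resolve_ngram_window(lookup_max: Optional[int],
                         lookup_min: Optional[int]) -> Tuple[int, int]:
    # Max must be provided and positive; min defaults to 0 and must stay
    # strictly below max.
    assert lookup_max is not None and lookup_max > 0
    if lookup_min is None:
        lookup_min = 0
    else:
        assert lookup_max > lookup_min
    return lookup_min, lookup_max


assert resolve_ngram_window(3, None) == (0, 3)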
View File

@ -75,6 +75,8 @@ class EngineArgs:
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
def __post_init__(self):
if self.tokenizer is None:
@ -449,6 +451,20 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=EngineArgs.ngram_prompt_lookup_max,
help='Max size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
@ -502,6 +518,8 @@ class EngineArgs:
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
)
scheduler_config = SchedulerConfig(

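With the new engine arguments in place, ngram speculative decoding can be enabled end to end roughly as follows; the model name and numbers are placeholders borrowed from the tests in this PR, not recommended settings:

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",      # placeholder target model
    speculative_model="[ngram]",     # use the NGramWorker drafter, no draft model
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,       # largest n-gram window matched against the prompt
    use_v2_block_manager=True,       # required for speculative decoding
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))

The same configuration maps onto the server flags added above, i.e. --speculative-model "[ngram]" together with --num-speculative-tokens, --ngram-prompt-lookup-max and --use-v2-block-manager.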
View File

@ -73,7 +73,6 @@ class GPUExecutor(ExecutorBase):
"""
assert self.speculative_config is not None
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
@ -86,10 +85,11 @@ class GPUExecutor(ExecutorBase):
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
draft_worker = MultiStepWorker(**draft_worker_kwargs)
spec_decode_worker = SpecDecodeWorker.from_workers(
proposer_worker=draft_worker, scorer_worker=target_worker)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

View File

@ -333,13 +333,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch(
[sampler_output])
[sampler_output], True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = (
sampler_output_to_torch([sampler_output]))
sampler_output_to_torch([sampler_output], True))
return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)

View File

@ -1,12 +1,11 @@
import copy
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@ -26,29 +25,37 @@ class MultiStepWorker(Worker):
super().__init__(*args, **kwargs)
# Lazy initialization list.
self._proposer: DraftModelTop1Proposer
self._proposer: Top1Proposer
def init_device(self):
super().init_device()
self._proposer = DraftModelTop1Proposer(
self._proposer = Top1Proposer(
self,
self.device,
self.max_model_len,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
# Need include_gpu_probs_tensor for multi_step_worker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@torch.inference_mode()
def execute_model_multi_step(
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler outputs, one per model forward pass, along with an indicator of
whether the torch tensors in the sampler output need to be transposed by
the later sampler_output_to_torch logic.
For the multi-step worker this indicator is always True.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
@ -58,12 +65,12 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
# Run model num_steps times.
# Run model sample_len times.
model_outputs = []
for _ in range(num_steps):
for _ in range(sample_len):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
@ -78,7 +85,7 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
return model_outputs, True
def get_spec_proposals(
self,
@ -206,171 +213,3 @@ class MultiStepWorker(Worker):
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
class DraftModelTop1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
draft_worker: MultiStepWorker,
device: str,
max_model_len: int,
vocab_size: int,
):
self._draft_worker = draft_worker
self._device = device
self._max_model_len = max_model_len
self._vocab_size = vocab_size
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
maybe_sampler_output = self._draft_worker.execute_model_multi_step(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_steps=max_proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
max_proposal_len=max_proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
max_proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length.
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
if seq_len + max_proposal_len < self._max_model_len:
proposal_lens.append(max_proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices)
def _merge_outputs(
self,
batch_size: int,
max_proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full(size=(batch_size,
*proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor

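MultiStepWorker above and the new NGramWorker now expose the same proposal hook that Top1Proposer calls; the Protocol below is only an editorial summary of that contract (vLLM does not define such a class), where the boolean tells sampler_output_to_torch whether the stacked tensors still need transposing:

from typing import Dict, List, Optional, Protocol, Tuple

from vllm.sequence import SamplerOutput, SequenceGroupMetadata


class ProposalWorker(Protocol):
    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # Return up to sample_len proposal steps plus the transpose flag:
        # True for MultiStepWorker, False for NGramWorker.
        ...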
View File

@ -0,0 +1,190 @@
from typing import Dict, List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scenerios
which don't rely on LLM model to give proposals.
"""
def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["model_config"].get_vocab_size()
# Lazy initialization list.
self._proposer: Top1Proposer
def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
ngram_prompt_lookup_max: int):
# Search for candidate windows with sizes between
# ngram_prompt_lookup_min and ngram_prompt_lookup_max.
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Currently only Top1Proposer is supported.
self._proposer = Top1Proposer(
self,
device=self.device,
vocab_size=self.vocab_size,
)
def set_include_gpu_probs_tensor(self):
# NGram does not need the GPU sampler.
pass
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def determine_num_available_blocks(self) -> None:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""As there is no cache need to handle, just pass this function"""
pass
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes."""
return 0
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self._raise_if_unsupported(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
arr = []
has_spec_out = False
for seq_group_metadata in seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))
input_ids = torch.as_tensor(seq_data.get_token_ids(),
dtype=torch.long,
device=self.device)
input_length = seq_data.get_len()
for ngram_size in range(
min(self.ngram_prompt_lookup_max, input_length - 1),
self.ngram_prompt_lookup_min,
-1,
):
ngram_tensor = input_ids[-1 * ngram_size:]
windows = input_ids.unfold(dimension=0,
size=ngram_size,
step=1)
matches = (windows == ngram_tensor).all(dim=1)
match_indices = matches.nonzero(as_tuple=True)[0]
if match_indices.size()[0] > 1:
has_spec_out = True
res = seq_data.get_token_ids()
res = res[match_indices[0] + ngram_size:match_indices[0] +
ngram_size + sample_len]
res_len = len(res)
# pad the output with 0s, since sample_len tokens are required
res += [0] * (sample_len - res_len)
break
else:
# if no candidate found, fill with 0
res = [0] * sample_len
arr.append(res)
if not has_spec_out:
return None, False
outputs = []
token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
indices = token_ids.unsqueeze(2)
token_probs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
token_probs.scatter_(2, indices, 1)
for i in range(len(seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_probs[i],
sampled_token_ids=token_ids[i],
))
return outputs, False
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
raise NotImplementedError(
"NGramWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"NGramWorker does not support beam search.")

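The unfold-based match at the heart of sampler_output can be tried in isolation; the snippet below reproduces it on one of the unit-test prompts to make the shapes concrete (a standalone sketch, not worker code):

import torch

# The trailing token 11 also appears at position 0, so a window of size 1
# should match there.
input_ids = torch.tensor([11, 12, 13, 14, 15, 16, 11], dtype=torch.long)

ngram_size = 1
ngram_tensor = input_ids[-ngram_size:]                     # trailing n-gram
windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
matches = (windows == ngram_tensor).all(dim=1)             # one bool per window
match_indices = matches.nonzero(as_tuple=True)[0]          # tensor([0, 6])

# More than one match means a non-trailing occurrence exists; the proposal
# is the run of tokens right after the first match, padded to sample_len
# with zeros in the real worker.
assert match_indices.numel() > 1
first = match_indices[0].item()
assert input_ids[first + ngram_size:first + ngram_size + 5].tolist() == [12, 13, 14, 15, 16]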
View File

@ -12,6 +12,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
@classmethod
def from_workers(cls, proposer_worker: MultiStepWorker,
scorer_worker: WorkerBase) -> "SpecDecodeWorker":
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs,
) -> "SpecDecodeWorker":
if "ngram_prompt_lookup_max" in draft_worker_kwargs:
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
else:
ngram_prompt_lookup_max = 0
if ngram_prompt_lookup_max > 0:
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
@ -59,7 +79,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
proposer_worker: MultiStepWorker,
proposer_worker: WorkerBase,
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
@ -134,8 +154,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True
self.proposer_worker.set_include_gpu_probs_tensor()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
@ -183,8 +202,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list")
logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
num_lookahead_slots)
#logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
# num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
@ -216,7 +235,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger.info("run proposer worker no spec")
#logger.info("run proposer worker no spec")
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
@ -225,7 +244,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_copy=blocks_to_copy,
)
logger.info("run target worker no spec")
#logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
@ -259,7 +278,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger.info("get spec proposals")
#logger.info("get spec proposals")
# Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
@ -268,7 +287,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
logger.info("score proposals")
#logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
@ -278,11 +297,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals,
)
logger.info("verify proposals")
#logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
logger.info("create output list")
#logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)

View File

@ -0,0 +1,200 @@
from typing import Dict, List, Optional, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase
class Top1Proposer(SpeculativeProposer):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def __init__(
self,
worker: WorkerBase,
device: str,
vocab_size: int,
max_proposal_len: Optional[int] = None,
):
self._worker = worker
self._device = device
self.max_proposal_len = max_proposal_len
self._vocab_size = vocab_size
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
proposal_len: int,
) -> SpeculativeProposals:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is True, maybe_sampler_output's token ids come
# as a list of length proposal_len, each entry covering the batch;
# if it is False, they come as a list of length batch_size, each entry
# covering proposal_len tokens.
maybe_sampler_output, transposed = self._worker.sampler_output(
seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
sample_len=proposal_len,
)
else:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output = None
transposed = False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
batch_size=len(seq_group_metadata_list),
proposal_len=proposal_len,
maybe_sampler_output=maybe_sampler_output,
proposal_lens=proposal_lens,
nonzero_proposal_len_indices=nonzero_proposal_len_indices,
sampler_transposed=transposed,
)
proposals = SpeculativeProposals(
proposal_token_ids=proposal_tokens,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens,
)
return proposals
def _split_by_max_model_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length."""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, we must not exceed this
# limit for sequences with a nonzero proposal len.
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
proposal_lens.append(proposal_len)
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
return (
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
)
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.full(
size=(
batch_size,
proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
proposal_probs = torch.zeros(
batch_size,
proposal_len,
self._vocab_size,
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. The proposal can be empty, e.g. [-1, -1, -1].
entire_proposal_tokens = torch.full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
dtype=torch.long,
device=self._device,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros(
batch_size,
*proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor

View File

@ -49,10 +49,13 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
sampler_output_list: List[SamplerOutput],
sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed indicates whether additional tensor transposition
logic is needed here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
@ -68,7 +71,10 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
)
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
@ -77,7 +83,9 @@ def sampler_output_to_torch(
for sampler_output in sampler_output_list
],
dim=0,
).transpose(0, 1)
)
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs
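To make the sampler_transposed flag concrete, the toy tensors below mimic the two layouts the callers produce (illustrative shapes only, batch_size=2 and k=3; not part of the PR):

import torch

batch_size, k = 2, 3

# MultiStepWorker emits one SamplerOutput per step, each holding one token
# id per sequence: k tensors of shape [batch_size]. Stacking along dim 0
# gives [k, batch_size], so sampler_transposed=True requests a transpose.
per_step = [torch.zeros(batch_size, dtype=torch.long) for _ in range(k)]
assert torch.stack(per_step, dim=0).transpose(0, 1).shape == (batch_size, k)

# NGramWorker emits one SamplerOutput per sequence, each already holding k
# token ids: batch_size tensors of shape [k]. Stacking already yields
# [batch_size, k], so sampler_transposed=False skips the transpose.
per_seq = [torch.zeros(k, dtype=torch.long) for _ in range(batch_size)]
assert torch.stack(per_seq, dim=0).shape == (batch_size, k)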