mirror of https://github.com/vllm-project/vllm
[Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
This commit is contained in:
parent 8b798eec75
commit b38e42fbca
@@ -1,4 +1,5 @@
import asyncio
from itertools import cycle
from typing import List, Optional, Tuple, Union

import pytest
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
        del llm

    return tokens, token_ids


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generated max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    (baseline_batch_tokens,
     baseline_batch_token_ids) = get_output_from_llm_generator(
         baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids
@@ -35,7 +35,8 @@ from transformers import AutoTokenizer

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator
from .conftest import (get_output_from_llm_generator,
                       run_greedy_equality_correctness_test)


@pytest.mark.parametrize(
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
    the same when temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generated max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    (baseline_batch_tokens,
     baseline_batch_token_ids) = get_output_from_llm_generator(
         baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids
@@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

The idea of ngram lookup comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, the test cases can be much
simpler than the multi-step draft-model ones.

However, we still need to verify that the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes

With those tests, we can say at least that ngram spec decoding does not break
the correctness of the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-68m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize("output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Verify greedy equality on a tiny model with different batch size."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 3,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
    ] + [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 1,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify that ngram speculative decoding produces exact equality to the
    output without spec decode, with many different values of k and different
    ngram_prompt_lookup_max.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
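The module docstring above describes how ngram prompt lookup proposes tokens by matching the trailing ngram of the context against earlier positions in the same sequence. Below is a minimal, self-contained sketch of that matching idea in plain Python; the function name and the simplified single-sequence, no-padding behavior are illustrative only and are not part of vLLM's API.

from typing import List, Optional


def propose_by_prompt_lookup(token_ids: List[int], ngram_max: int,
                             ngram_min: int,
                             num_speculative_tokens: int) -> Optional[List[int]]:
    """Propose draft tokens by finding an earlier occurrence of the trailing
    ngram and copying the tokens that followed it (prompt lookup decoding)."""
    # Try the largest ngram first, down to (but not including) ngram_min.
    for ngram_size in range(min(ngram_max, len(token_ids) - 1), ngram_min, -1):
        suffix = token_ids[-ngram_size:]
        # Scan earlier windows only; the trailing window always matches itself.
        for start in range(len(token_ids) - ngram_size):
            if token_ids[start:start + ngram_size] == suffix:
                follow = token_ids[start + ngram_size:start + ngram_size +
                                   num_speculative_tokens]
                if follow:
                    # May return fewer than num_speculative_tokens tokens;
                    # the real worker pads the proposal instead.
                    return follow
    return None


# Mirrors the expectation in the ngram worker tests: the suffix token 11
# reoccurs at position 0, so tokens 12..16 are proposed.
assert propose_by_prompt_lookup([11, 12, 13, 14, 15, 16, 11], 3, 0, 5) == [
    12, 13, 14, 15, 16
]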
@ -6,8 +6,8 @@ import torch
|
|||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
|
||||
MultiStepWorker)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
from .utils import (assert_logprobs_dict_allclose, create_batch,
|
||||
|
@ -117,8 +117,8 @@ def test_same_output_for_single_step():
|
|||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
actual_output = multi_step_worker.execute_model_multi_step(
|
||||
**multi_step_execute_model_data.to_dict(), num_steps=num_steps)
|
||||
actual_output, _ = multi_step_worker.sampler_output(
|
||||
**multi_step_execute_model_data.to_dict(), sample_len=num_steps)
|
||||
assert len(actual_output) == num_steps
|
||||
actual_output = actual_output[0]
|
||||
|
||||
|
@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
|
|||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
multi_step_output = multi_step_worker.execute_model_multi_step(
|
||||
**execute_model_data.to_dict(), num_steps=num_steps)
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
**execute_model_data.to_dict(), sample_len=num_steps)
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
|
@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
|
|||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_full_speculation_len():
|
||||
"""Verify DraftModelTop1Proposer correctly handles case where all sequences
|
||||
"""Verify Top1Proposer correctly handles case where all sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
|
@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
|
|||
device = 'cuda:0'
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = DraftModelTop1Proposer(
|
||||
draft_worker=draft_worker,
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
max_model_len=2048,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=2048,
|
||||
)
|
||||
draft_worker.execute_model_multi_step.return_value = [
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(batch_size,
|
||||
|
@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
|
|||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
]
|
||||
], True
|
||||
|
||||
execute_model_data, _, _ = create_batch(batch_size, k)
|
||||
|
||||
proposals = proposer.get_proposals(
|
||||
**execute_model_data.to_dict(),
|
||||
max_proposal_len=k,
|
||||
proposal_len=k,
|
||||
)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
|
@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
|
|||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_no_speculations():
|
||||
"""Verify DraftModelTop1Proposer correctly handles case where no sequences
|
||||
"""Verify Top1Proposer correctly handles case where no sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
|
@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
|
|||
prompt_len = 10
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = DraftModelTop1Proposer(
|
||||
draft_worker=draft_worker,
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
max_model_len=prompt_len + k - 1,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=prompt_len + k - 1,
|
||||
)
|
||||
|
||||
execute_model_data, _, _ = create_batch(batch_size,
|
||||
|
@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
|
|||
|
||||
proposals = proposer.get_proposals(
|
||||
**execute_model_data.to_dict(),
|
||||
max_proposal_len=k,
|
||||
proposal_len=k,
|
||||
)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
|
@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
|
|||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_mixed_k():
|
||||
"""Verify DraftModelTop1Proposer correctly handles case some sequences can
|
||||
"""Verify Top1Proposer correctly handles case some sequences can
|
||||
speculate and some can't.
|
||||
"""
|
||||
k = 10
|
||||
|
@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
|
|||
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = DraftModelTop1Proposer(
|
||||
draft_worker=draft_worker,
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
max_model_len=long_prompt_len + prev_output_token_len + k - 1,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
|
||||
)
|
||||
|
||||
draft_worker.execute_model_multi_step.return_value = [
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
|
||||
|
@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
|
|||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
]
|
||||
], True
|
||||
|
||||
execute_model_data, _, _ = create_batch(
|
||||
batch_size,
|
||||
|
@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
|
|||
|
||||
proposals = proposer.get_proposals(
|
||||
**execute_model_data.to_dict(),
|
||||
max_proposal_len=k,
|
||||
proposal_len=k,
|
||||
)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
|
|
|
@@ -0,0 +1,206 @@
import torch

from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import (create_execute_model_data,
                    create_seq_group_metadata_from_prompts, create_worker)


def test_ngram_algo_correctness_for_single_no_match():
    """Verify that our ngram algo finds the right candidate in the prompt.

    This covers the scenario where no candidate can be found for the single
    sequence in the batch.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([1])
    assert proposals.proposal_lens.tolist() == [0]


def test_ngram_algo_correctness_for_batches_not_match_all():
    """Verify that our ngram algo finds the right candidate in the prompt.

    This covers the scenario where candidates are found for some, but not all,
    sequences in the batch.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
        # shall find no candidate as it exceeds max_proposal_len
        [
            31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
            38, 31, 32, 33
        ],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([5])

    assert proposals.proposal_lens.tolist(
    ) == [proposal_len for _ in range(4)] + [0]

    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == 0
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
        assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
        assert proposals.proposal_token_ids[4][i] == -1


def test_ngram_algo_correctness_for_batches_match_all():
    """Verify that our ngram algo finds the right candidate in the prompt.

    This covers the scenario where a candidate is found for every sequence in
    the batch.
    """

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window (0, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # shall find candidate 12,13,14,15,16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23,24,25,26,21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34,35,36,37,38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([3])

    assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]

    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
@@ -682,6 +682,8 @@ class SpeculativeConfig:
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

@@ -708,6 +710,10 @@ class SpeculativeConfig:
            use_v2_block_manager (bool): Whether vLLM is configured to use the
                v2 block manager or not. Used for raising an error since the v2
                block manager is required with spec decode.
            ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
                window, if provided.
            ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
                window, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -742,39 +748,57 @@ class SpeculativeConfig:
            draft_code_revision = None
            draft_quantization = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=None,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )
        if speculative_model == "[ngram]":
            assert (ngram_prompt_lookup_max is not None
                    and ngram_prompt_lookup_max > 0)
            if ngram_prompt_lookup_min is None:
                ngram_prompt_lookup_min = 0
            else:
                assert ngram_prompt_lookup_max > ngram_prompt_lookup_min

        draft_model_config.max_model_len = (
            SpeculativeConfig._maybe_override_draft_max_model_len(
                speculative_max_model_len,
                draft_model_config.max_model_len,
                target_model_config.max_model_len,
            ))
            # TODO: currently we still need to extract vocab_size from the
            # target model config; in the future we may refactor it out and
            # set the draft-related config to None here.
            draft_model_config = target_model_config
            draft_parallel_config = target_parallel_config
        else:
            ngram_prompt_lookup_max = 0
            ngram_prompt_lookup_min = 0
            draft_model_config = ModelConfig(
                model=speculative_model,
                tokenizer=target_model_config.tokenizer,
                tokenizer_mode=target_model_config.tokenizer_mode,
                trust_remote_code=target_model_config.trust_remote_code,
                dtype=target_model_config.dtype,
                seed=target_model_config.seed,
                revision=draft_revision,
                code_revision=draft_code_revision,
                tokenizer_revision=target_model_config.tokenizer_revision,
                max_model_len=None,
                quantization=draft_quantization,
                enforce_eager=target_model_config.enforce_eager,
                max_context_len_to_capture=target_model_config.
                max_context_len_to_capture,
                max_logprobs=target_model_config.max_logprobs,
            )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))
            draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    speculative_max_model_len,
                    draft_model_config.max_model_len,
                    target_model_config.max_model_len,
                ))

            draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
        )

    @staticmethod
@@ -842,6 +866,8 @@ class SpeculativeConfig:
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        ngram_prompt_lookup_max: int,
        ngram_prompt_lookup_min: int,
    ):
        """Create a SpeculativeConfig object.

@@ -854,6 +880,8 @@ class SpeculativeConfig:
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

        self._verify_args()

@@ -877,7 +905,10 @@ class SpeculativeConfig:
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        if self.ngram_prompt_lookup_max > 0:
            draft_model = "[ngram]"
        else:
            draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"

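With the SpeculativeConfig changes above, ngram speculation is selected by passing "[ngram]" as the speculative model, which reuses the target model config instead of loading a draft model. A hedged usage sketch, assuming the LLM entrypoint forwards these engine arguments the same way the e2e test fixtures in this PR pass them; the target model choice is illustrative only:

from vllm import LLM, SamplingParams

# Assumes LLM(**kwargs) forwards these engine args (as the test fixtures do).
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",   # no separate draft model is loaded
    num_speculative_tokens=5,      # k tokens proposed per step
    ngram_prompt_lookup_max=3,     # largest ngram window to match
    ngram_prompt_lookup_min=1,     # optional; defaults to 0 when omitted
    use_v2_block_manager=True,     # required for spec decode
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))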
@@ -75,6 +75,8 @@ class EngineArgs:
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
    speculative_max_model_len: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
    ngram_prompt_lookup_min: Optional[int] = None

    def __post_init__(self):
        if self.tokenizer is None:
@@ -449,6 +451,20 @@ class EngineArgs:
            'draft model. Sequences over this length will skip '
            'speculation.')

        parser.add_argument(
            '--ngram-prompt-lookup-max',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_max,
            help='Max size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument(
            '--ngram-prompt-lookup-min',
            type=int,
            default=EngineArgs.ngram_prompt_lookup_min,
            help='Min size of window for ngram prompt lookup in speculative '
            'decoding.')

        parser.add_argument('--model-loader-extra-config',
                            type=str,
                            default=EngineArgs.model_loader_extra_config,
@@ -502,6 +518,8 @@ class EngineArgs:
            speculative_max_model_len=self.speculative_max_model_len,
            enable_chunked_prefill=self.enable_chunked_prefill,
            use_v2_block_manager=self.use_v2_block_manager,
            ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
        )

        scheduler_config = SchedulerConfig(
@@ -73,7 +73,6 @@ class GPUExecutor(ExecutorBase):
        """
        assert self.speculative_config is not None

        from vllm.spec_decode.multi_step_worker import MultiStepWorker
        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

        target_worker = self._create_worker()
@@ -86,10 +85,11 @@ class GPUExecutor(ExecutorBase):
            # TODO allow draft-model specific load config.
            #load_config=self.load_config,
        )
        draft_worker = MultiStepWorker(**draft_worker_kwargs)

        spec_decode_worker = SpecDecodeWorker.from_workers(
            proposer_worker=draft_worker, scorer_worker=target_worker)
        spec_decode_worker = SpecDecodeWorker.create_worker(
            scorer_worker=target_worker,
            draft_worker_kwargs=draft_worker_kwargs,
        )

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")
@@ -333,13 +333,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
        sampler_output.sampled_token_probs = spec_probs
        sampler_output.sampled_token_ids = spec_sampled_tokens
        target_token_ids, target_probs = sampler_output_to_torch(
            [sampler_output])
            [sampler_output], True)

        # Convert non-speculative output tokens to tensors.
        sampler_output.sampled_token_probs = non_spec_probs
        sampler_output.sampled_token_ids = non_spec_sampled_tokens
        non_spec_target_token_ids, non_spec_target_probs = (
            sampler_output_to_torch([sampler_output]))
            sampler_output_to_torch([sampler_output], True))

        return (target_token_ids, target_probs, non_spec_target_token_ids,
                non_spec_target_probs)
@@ -1,12 +1,11 @@
import copy
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


@@ -26,29 +25,37 @@ class MultiStepWorker(Worker):
        super().__init__(*args, **kwargs)

        # Lazy initialization list.
        self._proposer: DraftModelTop1Proposer
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()

        self._proposer = DraftModelTop1Proposer(
        self._proposer = Top1Proposer(
            self,
            self.device,
            self.max_model_len,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        # Need include_gpu_probs_tensor for multi_step_worker
        self.model_runner.model.sampler.include_gpu_probs_tensor = True

    @torch.inference_mode()
    def execute_model_multi_step(
    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns the list of
        sampler output, one per model forward pass.
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times. Returns the list of
        sampler output, one per model forward pass, along with an indicator of
        whether the torch tensors in the sampler output need to be transposed
        later by the sampler_output_to_torch logic.

        For the multi-step worker, this indicator shall be True.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)
@@ -58,12 +65,12 @@ class MultiStepWorker(Worker):
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
        # Assert enough KV space for sample_len tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)

        # Run model num_steps times.
        # Run model sample_len times.
        model_outputs = []
        for _ in range(num_steps):
        for _ in range(sample_len):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
@@ -78,7 +85,7 @@ class MultiStepWorker(Worker):
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs
        return model_outputs, True

    def get_spec_proposals(
        self,
@@ -206,171 +213,3 @@ class MultiStepWorker(Worker):
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")


class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        draft_worker: MultiStepWorker,
        device: str,
        max_model_len: int,
        vocab_size: int,
    ):
        self._draft_worker = draft_worker
        self._device = device
        self._max_model_len = max_model_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """

        # Split speculative- and non-speculative- sequences.
        (proposal_lens, nonzero_proposal_len_seqs,
         nonzero_proposal_len_indices) = self._split_by_max_model_len(
             seq_group_metadata_list, max_proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_steps=max_proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            max_proposal_len=max_proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (proposal_lens, nonzero_proposal_len_seqs,
                nonzero_proposal_len_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        max_proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.full(size=(
                batch_size,
                max_proposal_len,
            ),
                                         fill_value=-1,
                                         dtype=torch.long,
                                         device=self._device)
            proposal_probs = torch.zeros(batch_size,
                                         max_proposal_len,
                                         self._vocab_size,
                                         dtype=torch.float32,
                                         device=self._device)
            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                               dtype=torch.long,
                                               device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
                                            dtype=torch.float32,
                                            device=self._device)
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (entire_proposal_tokens,
                                           entire_proposal_probs)

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
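MultiStepWorker.sampler_output above runs the draft model sample_len times, appending each step's sampled token before the next forward pass. Below is a minimal, model-free sketch of that loop shape with a stand-in step function; the names are illustrative and not part of vLLM's API.

from typing import Callable, List


def multi_step_draft(token_ids: List[int],
                     step_fn: Callable[[List[int]], int],
                     sample_len: int) -> List[int]:
    """Autoregressive draft loop: one "forward pass" per proposed token,
    feeding each sampled token back into the context before the next step."""
    draft: List[int] = []
    context = list(token_ids)
    for _ in range(sample_len):
        next_token = step_fn(context)  # stand-in for the draft model forward
        draft.append(next_token)
        context.append(next_token)
    return draft


# Toy step function that just emits the previous token plus one.
assert multi_step_draft([1, 2, 3], lambda ctx: ctx[-1] + 1, 4) == [4, 5, 6, 7]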
@@ -0,0 +1,190 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a lightweight drafter without the need for a model.

    The current NGramWorker only implements prompt lookup decoding; in the
    future we may also add RAG-type drafters and other scenarios which don't
    rely on an LLM model to give proposals.
    """

    def __init__(self, *args, **kwargs):
        # Get local_rank/vocab_size from kwargs attribute
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazy initialization list.
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        # Search for a valid candidate window between
        # ngram_prompt_lookup_min/ngram_prompt_lookup_max
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None

        # Currently only Top1Proposer is supported
        self._proposer = Top1Proposer(
            self,
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def set_include_gpu_probs_tensor(self):
        # NGram doesn't need the GPU sampler
        pass

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> None:
        """NGram doesn't depend on model execution, so this is a no-op"""
        pass

    def determine_num_available_blocks(self) -> None:
        """NGram doesn't depend on model execution, no need to check blocks"""
        pass

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """As there is no cache to handle, this function is a no-op"""
        pass

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes."""
        return 0

    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """NGram match algo to pick proposal candidate. Returns the list of
        sampler output, one per SequenceGroupMetadata.

        For the ngram worker, the needed transpose is already done internally,
        so the indicator passed to sampler_output_to_torch shall be False.
        """
        self._raise_if_unsupported(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

        arr = []
        has_spec_out = False
        for seq_group_metadata in seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()

            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    self.ngram_prompt_lookup_min,
                    -1,
            ):
                ngram_tensor = input_ids[-1 * ngram_size:]
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                matches = (windows == ngram_tensor).all(dim=1)
                match_indices = matches.nonzero(as_tuple=True)[0]
                if match_indices.size()[0] > 1:
                    has_spec_out = True
                    res = seq_data.get_token_ids()
                    res = res[match_indices[0] + ngram_size:match_indices[0] +
                              ngram_size + sample_len]
                    res_len = len(res)
                    # pad the output with 0 as sample_len tokens are required
                    res += [0] * (sample_len - res_len)

                    break
            else:
                # if no candidate found, fill with 0
                res = [0] * sample_len

            arr.append(res)

        if not has_spec_out:
            return None, False

        outputs = []
        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
        indices = token_ids.unsqueeze(2)

        token_probs = torch.zeros(
            (len(seq_group_metadata_list), sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
        for i in range(len(seq_group_metadata_list)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs[i],
                    sampled_token_ids=token_ids[i],
                ))
        return outputs, False

    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")
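NGramWorker.sampler_output above turns the proposed token ids into one-hot probability tensors with scatter_, so the rest of the spec-decode pipeline can treat them like real sampler probabilities. A small standalone example of that scatter pattern; the tensor sizes are illustrative only.

import torch

batch_size, sample_len, vocab_size = 2, 3, 8
token_ids = torch.tensor([[4, 1, 7],
                          [0, 5, 5]], dtype=torch.long)

# One-hot distribution per proposed token: put probability 1.0 at the
# proposed token id along the vocab dimension.
token_probs = torch.zeros(batch_size, sample_len, vocab_size,
                          dtype=torch.float32)
token_probs.scatter_(2, token_ids.unsqueeze(2), 1)

assert token_probs.shape == (batch_size, sample_len, vocab_size)
assert token_probs[0, 0, 4] == 1.0 and token_probs[0, 0].sum() == 1.0
assert token_probs[1, 2, 5] == 1.0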
@@ -12,6 +12,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

@@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """

    @classmethod
    def from_workers(cls, proposer_worker: MultiStepWorker,
                     scorer_worker: WorkerBase) -> "SpecDecodeWorker":
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs,
    ) -> "SpecDecodeWorker":

        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
            ngram_prompt_lookup_max = (
                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
            ngram_prompt_lookup_min = (
                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
        else:
            ngram_prompt_lookup_max = 0

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
@@ -59,7 +79,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

    def __init__(
        self,
        proposer_worker: MultiStepWorker,
        proposer_worker: WorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
@@ -134,8 +154,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        (self.proposer_worker.model_runner.model.sampler.
         include_gpu_probs_tensor) = True
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

@@ -183,8 +202,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
                    num_lookahead_slots)
        #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
        #            num_lookahead_slots)

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
@@ -216,7 +235,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        proposer and scorer model so that the KV cache is consistent between the
        two.
        """
        logger.info("run proposer worker no spec")
        #logger.info("run proposer worker no spec")

        self.proposer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
@@ -225,7 +244,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            blocks_to_copy=blocks_to_copy,
        )

        logger.info("run target worker no spec")
        #logger.info("run target worker no spec")
        sampler_output = self.scorer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
@@ -259,7 +278,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        sequence.
        """

        logger.info("get spec proposals")
        #logger.info("get spec proposals")
        # Generate proposals using draft worker.
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
@@ -268,7 +287,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
            blocks_to_copy, k)

        logger.info("score proposals")
        #logger.info("score proposals")
        proposal_scores = self.scorer.score_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
@@ -278,11 +297,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            proposals,
        )

        logger.info("verify proposals")
        #logger.info("verify proposals")
        accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
                                                 proposal_scores, proposals, k)

        logger.info("create output list")
        #logger.info("create output list")
        return self._create_output_sampler_list(seq_group_metadata_list,
                                                accepted_token_ids, k)

@@ -0,0 +1,200 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase


class Top1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        worker: WorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        self._worker = worker
        self._device = device
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids are laid out as [batch] entries in a proposal_len-sized
            # list; if it is false, the layout is [proposal_len] entries in a
            # batch-sized list.
            maybe_sampler_output, transposed = self._worker.sampler_output(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                sample_len=proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length."""

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            # If max_proposal_len is defined, then we shall not exceed this
            # quota for nonzero_proposal
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                proposal_lens.append(proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.full(
                size=(
                    batch_size,
                    proposal_len,
                ),
                fill_value=-1,
                dtype=torch.long,
                device=self._device,
            )
            proposal_probs = torch.zeros(
                batch_size,
                proposal_len,
                self._vocab_size,
                dtype=torch.float32,
                device=self._device,
            )
            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                               dtype=torch.long,
                                               device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(
            batch_size,
            *proposal_probs.shape[1:],
            dtype=torch.float32,
            device=self._device,
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
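Top1Proposer._split_by_max_model_len above decides, per sequence, whether speculating proposal_len more tokens would exceed the proposer's limit; such sequences get proposal length 0 and fall back to normal decoding. Below is a standalone sketch of that decision with made-up lengths; the function name is illustrative and not vLLM's API.

from typing import List, Optional, Tuple


def split_by_max_proposal_len(seq_lens: List[int], proposal_len: int,
                              max_proposal_len: Optional[int]
                              ) -> Tuple[List[int], List[int]]:
    """Return (per-sequence proposal lens, indices of sequences that may
    speculate)."""
    proposal_lens: List[int] = []
    nonzero_indices: List[int] = []
    for i, seq_len in enumerate(seq_lens):
        if max_proposal_len is None or seq_len + proposal_len < max_proposal_len:
            proposal_lens.append(proposal_len)
            nonzero_indices.append(i)
        else:
            # Too close to the length limit: skip speculation for this one.
            proposal_lens.append(0)
    return proposal_lens, nonzero_indices


# With a limit of 20 and k=5, a sequence of length 21 cannot speculate,
# mirroring the last prompt in test_ngram_algo_correctness_for_batches_not_match_all.
assert split_by_max_proposal_len([7, 9, 21], 5, 20) == ([5, 5, 0], [0, 1])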
@@ -49,10 +49,13 @@ def split_batch_by_proposal_len(


def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
        sampler_output_list: List[SamplerOutput],
        sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed here is used as the indicator of whether
    we need to do additional tensor transpose logic here.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]
@@ -68,7 +71,10 @@ def sampler_output_to_torch(
            for sampler_output in sampler_output_list
        ],
        dim=0,
    ).transpose(0, 1)
    )

    if sampler_transposed:
        sampled_token_probs = sampled_token_probs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(
@@ -77,7 +83,9 @@ def sampler_output_to_torch(
            for sampler_output in sampler_output_list
        ],
        dim=0,
    ).transpose(0, 1)
    )
    if sampler_transposed:
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    return sampled_token_ids, sampled_token_probs

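The sampler_transposed flag added above exists because the multi-step worker stacks one SamplerOutput per step, giving [num_steps, batch_size] tensors that must be transposed, while the ngram worker already emits per-sequence tensors in [batch_size, sample_len] layout. A small tensor-only sketch of the two layouts; the shapes and values are illustrative.

import torch

batch_size, k = 2, 3

# Multi-step layout: one row per forward pass -> [k, batch_size]; needs a
# transpose to become [batch_size, k] (sampler_transposed=True).
per_step = torch.stack([torch.tensor([10 + b for b in range(batch_size)])
                        for _ in range(k)], dim=0)
assert per_step.shape == (k, batch_size)
assert per_step.transpose(0, 1).shape == (batch_size, k)

# Ngram layout: already one row per sequence -> [batch_size, k]; no
# transpose is needed (sampler_transposed=False).
per_seq = torch.zeros(batch_size, k, dtype=torch.long)
assert per_seq.shape == (batch_size, k)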