mirror of https://github.com/vllm-project/vllm

[Misc][Refactor] Introduce ExecuteModelData (#4540)
parent 344bf7cd2d
commit bc8ad68455
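Overview (a condensed sketch of the refactor shown in the hunks below): the loose keyword arguments that previously flowed through every execute_model call (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, num_lookahead_slots) are bundled into a single ExecuteModelRequest dataclass defined in vllm/sequence.py. The before/after below is illustrative only; `worker` stands in for any of the workers and executors touched by this diff.

# Old call style (removed by this commit)
worker.execute_model(seq_group_metadata_list=seq_group_metadata_list,
                     blocks_to_swap_in={},
                     blocks_to_swap_out={},
                     blocks_to_copy={},
                     num_lookahead_slots=k)

# New call style (introduced by this commit)
from vllm.sequence import ExecuteModelRequest
worker.execute_model(execute_model_req=ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    num_lookahead_slots=k))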
@@ -5,13 +5,12 @@ import pytest
import torch

from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker

from .utils import (assert_logprobs_dict_allclose, create_batch,
create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)

@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
**multi_step_execute_model_data.to_dict(), sample_len=num_steps)
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps)
assert len(actual_output) == num_steps
actual_output = actual_output[0]

single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
**single_step_execute_model_data.to_dict(), )[0]
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]

actual_token_ids = [
output.samples[0].output_token for output in actual_output

@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

continuations = [[1] for _ in prompts]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens), )
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)

# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
**execute_model_data.to_dict(), sample_len=num_steps)
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps)

# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)

@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
for _ in multi_step_output:

execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)

single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))

# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):

@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
) for _ in range(k)
], True

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
max_proposal_len=prompt_len + k - 1,
)

execute_model_data, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
) for _ in range(k)
], True

execute_model_data, _, _ = create_batch(
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)

proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
proposal_len=k,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
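Every test above follows the same pattern once the ExecuteModelData helper is gone: build the seq_group_metadata_list with the existing helpers and wrap it in an ExecuteModelRequest at the call site. A condensed sketch, assuming the helpers from tests/spec_decode/utils.py are in scope:

seq_group_metadata_list = create_seq_group_metadata_from_prompts(
    prompts, num_gpu_blocks, block_size, final_prompt_lens=final_prompt_lens)
actual_output, _ = multi_step_worker.sampler_output(
    execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list),
    sample_len=num_steps)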
@@ -1,10 +1,10 @@
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import (create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker)
from .utils import create_seq_group_metadata_from_prompts, create_worker


def test_ngram_algo_correctness_for_single_no_match():

@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)

@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len), )

assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -7,7 +7,7 @@ import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)

@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)

from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
mock_worker)
from .utils import create_batch, create_sampler_output_list, mock_worker


@pytest.mark.parametrize('k', [1, 2, 6])

@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
worker.execute_model(execute_model_req=execute_model_req)

call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1

for args, _ in call_args_list:
(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, actual_k) = args
actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy)
assert actual_execute_model_data == execute_model_data
assert actual_k == k
actual_execute_model_data = args[0]
assert actual_execute_model_data == execute_model_req


@pytest.mark.parametrize('k', [1, 2, 6])

@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k

execute_model_data, prompts, prev_output_tokens = create_batch(
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
batch_size, k)

draft_worker.get_spec_proposals.return_value = SpeculativeProposals(

@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker.execute_model.side_effect = ValueError(exception_secret)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))

seen_contexts = []

call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for args, kwargs in call_args_list:
target_execute_model_data = ExecuteModelData.from_dict(kwargs)
for _, kwargs in call_args_list:
seq_group_metadata_list = kwargs[
"execute_model_req"].seq_group_metadata_list

assert len(target_execute_model_data.seq_group_metadata_list) == (
k + 1) * batch_size
for seq_group_metadata in (
target_execute_model_data.seq_group_metadata_list):
assert len(seq_group_metadata_list) == (k + 1) * batch_size
for seq_group_metadata in seq_group_metadata_list:
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())

@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)

draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,

@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
rejection_sampler.side_effect = ValueError(exception_secret)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))

assert len(rejection_sampler.call_args_list) == 1
_, kwargs = rejection_sampler.call_args_list[0]

@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)

draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,

@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler.return_value = rejection_sampler_output

output = worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))

expected_output = create_sampler_output_list(
token_ids=rejection_sampler_output.transpose(0, 1),

@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in execute_model_data.seq_group_metadata_list
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}

@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k

execute_model_data, _, _ = create_batch(batch_size, k)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)

draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,

@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)

output = worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

call_args_list = (

@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

out = worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)

assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"

draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.parametrize('k', [0, 5])

@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)

execute_model_data, prompts, prev_output_tokens = create_batch(
batch_size, k, prev_output_token_len=0)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

out = worker.execute_model(**execute_model_data.to_dict(),
num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)

assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
assert out[
0].sampled_tokens is None, "expect gpu tensor references to be None"

draft_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)


@pytest.mark.skip_global_cleanup
@@ -1,4 +1,3 @@
from dataclasses import dataclass, fields
from itertools import count
from typing import Dict, Iterable, List, Optional, Union
from unittest.mock import MagicMock

@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker


@dataclass
class ExecuteModelData:
"""Helper data structure which facilitates cleaner tests.
"""
seq_group_metadata_list: List[SequenceGroupMetadata]
blocks_to_swap_in: Dict[int, int]
blocks_to_swap_out: Dict[int, int]
blocks_to_copy: Dict[int, List[int]]

def to_dict(self):
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))

@classmethod
def from_dict(cls, d):
cleaned = dict((field.name, d[field.name]) for field in fields(cls))
return cls(**cleaned)


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size


def create_execute_model_data(
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, int]] = None,
) -> ExecuteModelData:
if blocks_to_swap_in is None:
blocks_to_swap_in = {}
if blocks_to_swap_out is None:
blocks_to_swap_out = {}
if blocks_to_copy is None:
blocks_to_copy = {}

return ExecuteModelData(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)


def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,

@@ -258,8 +217,7 @@ def create_batch(batch_size,
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]

execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_prompt_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens
@@ -1,6 +1,7 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker

@@ -54,10 +55,14 @@ def test_swap() -> None:
# Test swap out.
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
worker.execute_model(seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={})
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={},
)
worker.execute_model(execute_model_req=execute_model_req)

for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]

@@ -66,14 +71,19 @@ def test_swap() -> None:
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])

# Test swap in.
blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
worker.execute_model(seq_group_metadata_list=[],
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out={},
blocks_to_copy={})
execute_model_req.blocks_to_swap_out = {}
execute_model_req.blocks_to_swap_in = {
19: 45,
67: 23,
12: 78,
40: 99,
1: 71
}
worker.execute_model(execute_model_req=execute_model_req)

for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in blocks_to_swap_in.items():
for src, dst in execute_model_req.blocks_to_swap_in.items():
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
@@ -128,6 +128,8 @@ class SchedulerOutputs:
ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# The number of requests in the running queue
running_queue_size: int

def __post_init__(self):
# Swap in and swap out should never happen at the same time.

@@ -797,6 +799,7 @@ class Scheduler:
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
)

def _schedule_chunked_prefill(self):

@@ -883,6 +886,7 @@ class Scheduler:
swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
)

def _schedule(self) -> SchedulerOutputs:
@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData, SamplerOutput
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)

@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):
if not scheduler_outputs.is_empty():
# Execute the model.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = await self.model_executor.execute_model_async(
seq_group_metadata_list,
scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_swap_out,
scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
execute_model_req)
else:
output = []
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,

@@ -583,12 +583,16 @@ class LLMEngine:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model(
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else:
output = []
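The design intent visible in this hunk and the async variant above: LLMEngine.step() now folds everything it used to pass as separate arguments into one ExecuteModelRequest, so adding a new per-step field (here running_queue_size) only touches the request dataclass and the call sites that read it, not every executor and worker signature. A minimal sketch of the equivalence, assuming a SchedulerOutputs value named scheduler_outputs:

execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
    blocks_to_copy=scheduler_outputs.blocks_to_copy,
    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
    running_queue_size=scheduler_outputs.running_queue_size,
)
output = self.model_executor.execute_model(execute_model_req=execute_model_req)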
@@ -1,4 +1,4 @@
from typing import Dict, List, Set, Tuple
from typing import List, Set, Tuple

import torch

@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)

@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:

@@ -104,19 +96,10 @@ class CPUExecutor(ExecutorBase):
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots)
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output

async def check_health_async(self) -> None:
@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple
from typing import List, Optional, Set, Tuple

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput


class ExecutorBase(ABC):

@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
raise NotImplementedError

@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError

@@ -107,13 +104,8 @@ class ExecutorAsyncBase(ExecutorBase):
@abstractmethod
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
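Under the new abstract interface, a backend executor only has to accept the request object and forward it; the per-field plumbing disappears. A hypothetical minimal subclass, shown only to illustrate the signature (MyExecutor and its driver_worker are placeholders, not part of this diff, and the other abstract methods are omitted):

class MyExecutor(ExecutorBase):

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # Forward the bundled request to the driver worker unchanged.
        return self.driver_worker.execute_model(execute_model_req)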
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase

@@ -117,20 +117,9 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots,
)
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:

@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots)
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
@@ -1,9 +1,9 @@
from typing import Dict, List, Set, Tuple
from typing import List, Set, Tuple

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async

logger = init_logger(__name__)

@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
assert (execute_model_req.blocks_to_swap_in == {}
and execute_model_req.blocks_to_swap_out == {}
and execute_model_req.blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")
assert num_lookahead_slots == 0, (
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")

output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list)
execute_model_req.seq_group_metadata_list)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:

@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list, )
output = await make_async(
self.driver_worker.execute_model
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
return output

async def check_health_async(self) -> None:
@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)

@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int = 0) -> List[SamplerOutput]:
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
"num_lookahead_slots": num_lookahead_slots,
},
driver_kwargs={"execute_model_req": execute_model_req},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

# Only the driver worker returns the sampling results.
@@ -1,7 +1,7 @@
"""Sequence and its related classes."""
import copy
import enum
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock

@@ -734,3 +734,33 @@ class SamplerOutput:
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


@dataclass
class ExecuteModelRequest:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
running_queue_size: int = 0

def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
)
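A small usage sketch of the new dataclass (illustrative only; seq_group_metadata_list and new_seq_group_metadata_list are placeholders): every field except seq_group_metadata_list has a default, and clone() copies the block maps so a caller can substitute a different metadata list without mutating the original request.

from vllm.sequence import ExecuteModelRequest

req = ExecuteModelRequest(seq_group_metadata_list=seq_group_metadata_list,
                          num_lookahead_slots=4)
assert req.blocks_to_swap_in == {} and req.running_queue_size == 0
# Same block maps and counters, different sequence groups.
expanded = req.clone(new_seq_group_metadata_list)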
@@ -1,9 +1,10 @@
from itertools import chain, count
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Iterator, List, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,

@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.

@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
no speculation is produced for that sequence.

Args:
seq_group_metadata_list: The input sequence group metadata.
blocks_to_swap_in: This is passed to the worker during scoring.
blocks_to_swap_out: This is passed to the worker during scoring.
blocks_to_copy: This is passed to the worker during scoring.
k: The fixed proposal length.
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with

@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)

target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, ))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list),
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=k,
k=execute_model_req.num_lookahead_slots,
)

return SpeculativeScores(
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch

from vllm.sequence import SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest


@dataclass

@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
@abstractmethod
def get_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
raise NotImplementedError

@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
@abstractmethod
def score_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
raise NotImplementedError
@@ -1,9 +1,10 @@
import copy
from typing import Dict, List, Tuple
from typing import List, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker

@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
@torch.inference_mode()
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of

@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
self._raise_if_unsupported(execute_model_req)

# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)

# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)

# Run model sample_len times.
model_outputs = []
for _ in range(sample_len):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]

@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""

return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
return self._proposer.get_proposals(execute_model_req)

def _append_new_tokens(
self, model_output: SamplerOutput,

@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")

if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
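The clone() helper added in vllm/sequence.py is what lets MultiStepWorker keep its copy-on-write behaviour: it shallow-copies the sequence group metadata and then re-wraps it in a fresh request, so the sample_len forward passes append tokens only to the copy. A condensed sketch of that loop (not a complete method):

copied_seq_group_metadata_list = self._shallow_copy_inputs(
    execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
    copied_seq_group_metadata_list)
for _ in range(sample_len):
    model_output = super().execute_model(
        execute_model_req=copied_execute_model_req)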
@@ -1,8 +1,8 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# NGram don't need gpu sampler
pass

def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> None:
def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass

@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def sampler_output(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of

@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self._raise_if_unsupported(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
self._raise_if_unsupported(execute_model_req)

arr = []
has_spec_out = False
for seq_group_metadata in seq_group_metadata_list:
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values()))

input_ids = torch.as_tensor(seq_data.get_token_ids(),

@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
indices = token_ids.unsqueeze(2)

token_probs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
(len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(seq_group_metadata_list)):
for i in range(len(execute_model_req.seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,

@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def get_spec_proposals(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""

return self._proposer.get_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
return self._proposer.get_proposals(execute_model_req)

def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
execute_model_req: ExecuteModelRequest,
) -> None:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"NGramWorker does not support cache operations")

if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"NGramWorker does not support beam search.")
@ -1,11 +1,12 @@
from functools import cached_property
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)

@ -189,69 +190,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """

        assert seq_group_metadata_list is not None, (
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
        #            num_lookahead_slots)

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
        if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
            return self._run_no_spec(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )
        if execute_model_req.num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list) == 0:
            return self._run_no_spec(execute_model_req)

        return self._run_speculative_decoding_step(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            k=num_lookahead_slots,
        )
        return self._run_speculative_decoding_step(execute_model_req)

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> List[SamplerOutput]:
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run a prefill step, without any speculation. The input is sent to the
        proposer and scorer model so that the KV cache is consistent between the
        two.
        """
        #logger.info("run proposer worker no spec")

        self.proposer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        self.proposer_worker.execute_model(execute_model_req)

        #logger.info("run target worker no spec")
        sampler_output = self.scorer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

@ -264,13 +233,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
    ) -> List[SamplerOutput]:
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each

@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

        #logger.info("get spec proposals")
        # Generate proposals using draft worker.
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None
        proposals = self.proposer_worker.get_spec_proposals(
            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
            blocks_to_copy, k)
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

        #logger.info("score proposals")
        proposal_scores = self.scorer.score_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            k,
            execute_model_req,
            proposals,
        )

        #logger.info("verify proposals")
        accepted_token_ids, target_logprobs = self._verify_tokens(
            seq_group_metadata_list, proposal_scores, proposals, k)
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        #logger.info("create output list")
        return self._create_output_sampler_list(
            seq_group_metadata_list,
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=k)
            k=execute_model_req.num_lookahead_slots)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(

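A short editorial sketch (not from the commit) of what the new dispatch above means for a caller: the same entry point now serves both the prefill path and the speculative path, selected purely by num_lookahead_slots on the request. The names spec_decode_worker and seq_group_metadata_list are assumed to already exist.

from vllm.sequence import ExecuteModelRequest

# num_lookahead_slots == 0 -> _run_no_spec (prefill; proposer and scorer run normally).
prefill_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    num_lookahead_slots=0,
)
prefill_outputs = spec_decode_worker.execute_model(prefill_req)

# num_lookahead_slots == k > 0 -> _run_speculative_decoding_step with k proposals.
decode_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    num_lookahead_slots=5,
)
decode_outputs = spec_decode_worker.execute_model(decode_req)
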
@ -1,8 +1,9 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch

@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        proposal_len: int,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (

@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            maybe_sampler_output, transposed = self._worker.sampler_output(
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                num_lookahead_slots=proposal_len,
            )
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
            )
        else:

@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:

        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            assert len(blocks_to_swap_in) == 0
            assert len(blocks_to_swap_out) == 0
            assert execute_model_req is not None
            blocks_to_copy = execute_model_req.blocks_to_copy
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": blocks_to_copy,
                "blocks_to_copy": execute_model_req.blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:

@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        assert blocks_to_copy is not None
        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.

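A small editorial sketch of the constraint the assertions above encode: a request sent to the CPU worker may carry copy-on-write block copies, but its swap-in/swap-out maps must stay empty. The names cpu_worker and seq_group_metadata_list are assumed to exist; the block ids are illustrative.

from vllm.sequence import ExecuteModelRequest

cpu_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},       # asserted empty in CPUWorker.execute_model
    blocks_to_swap_out={},      # asserted empty in CPUWorker.execute_model
    blocks_to_copy={0: [1]},    # source block id -> list of destination block ids
)
outputs = cpu_worker.execute_model(execute_model_req=cpu_req)
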
@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
    init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase

@ -211,19 +211,21 @@ class Worker(WorkerBase):
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
        num_lookahead_slots: int = 0,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:

        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            assert execute_model_req is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
            blocks_to_copy = execute_model_req.blocks_to_copy
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,

@ -238,9 +240,6 @@ class Worker(WorkerBase):
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.

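Similarly, an editorial sketch of the GPU worker path above: the swap and copy instructions that used to be separate keyword arguments now arrive as fields of the request and are unpacked by the driver before broadcast. The names worker and seq_group_metadata_list are assumed; the block ids are illustrative.

from vllm.sequence import ExecuteModelRequest

gpu_req = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={3: 7},    # illustrative ids: CPU block 3 swapped into GPU block 7
    blocks_to_swap_out={9: 2},   # illustrative ids: GPU block 9 swapped out to CPU block 2
    blocks_to_copy={},
)
outputs = worker.execute_model(execute_model_req=gpu_req)
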
@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread,
                        update_environment_variables)

@ -48,10 +48,8 @@ class WorkerBase(ABC):

    @abstractmethod
    def execute_model(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
                                                                        int],
            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        raise NotImplementedError

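Finally, an editorial sketch of the contract the new abstract signature establishes: any worker now receives a single ExecuteModelRequest and returns a list of SamplerOutput. The class below is a stand-alone toy used only to illustrate that call shape; it is not a real vllm worker and does not implement the rest of WorkerBase.

from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput


class CountingWorker:
    """Toy stand-in mirroring the new WorkerBase.execute_model signature."""

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # A real worker would run a model step here; the toy only reports batch size.
        print("sequence groups:", len(execute_model_req.seq_group_metadata_list))
        return []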