mirror of https://github.com/vllm-project/vllm
[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)
This commit is contained in:
parent 70c094ade6
commit 1856aff4d6
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():

 def test_filter_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)

     expected_groups = [
         fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]

@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):

 def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 2]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)

     expected_groups = [
         fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]

@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):

 def test_empty_inputs():
-    filtered_groups, indices = split_batch_by_proposal_len(
-        [], [], select_proposal_len_zero=True)
+    _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])

     assert filtered_groups == []
     assert indices == []

@@ -95,10 +92,9 @@ def test_empty_inputs():

 def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [0, 0, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)

     assert filtered_groups == []
     assert indices == []

@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):

 def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [1, 1, 1]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)

     assert filtered_groups == []
     assert indices == []
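As the updated tests illustrate, split_batch_by_proposal_len no longer takes a select_proposal_len_zero flag; a single call returns both partitions, nonzero-proposal-len ("spec") groups first and zero-proposal-len ("non-spec") groups second. A minimal migration sketch for a call site (the metadata list here is a placeholder, not from this diff):

    # Old: two passes over the batch, one per filter direction.
    # spec_seqs, spec_indices = split_batch_by_proposal_len(
    #     seq_group_metadata_list, proposal_lens, select_proposal_len_zero=False)
    # non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
    #     seq_group_metadata_list, proposal_lens, select_proposal_len_zero=True)

    # New: one call, both partitions at once.
    (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
        split_batch_by_proposal_len(seq_group_metadata_list, proposal_lens)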
@@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
                            get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
-                                   split_batch_by_proposal_len)
+from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
 from vllm.worker.worker_base import WorkerBase

 SeqId = int

@@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

-        (all_tokens, all_probs, spec_logprobs,
-         all_hidden_states) = self._contract_batch(
-             contracted_bs=len(execute_model_req.seq_group_metadata_list),
-             target_sampler_output=target_sampler_output,
-             proposals=proposals,
-             num_scoring_tokens=num_scoring_tokens,
-             non_spec_indices=non_spec_indices,
-             spec_indices=spec_indices,
-             k=execute_model_req.num_lookahead_slots,
-         )
-
+        if not non_spec_indices:
+            # All sequence groups in batch have spec decoding enabled
+            contracted = self._contract_batch_all_spec(
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+            )
+        else:
+            # Batch has a mix of spec decode enabled and disabled seq groups
+            contracted = self._contract_batch(
+                contracted_bs=len(execute_model_req.seq_group_metadata_list),
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+                num_scoring_tokens=num_scoring_tokens,
+                non_spec_indices=non_spec_indices,
+                spec_indices=spec_indices,
+                k=execute_model_req.num_lookahead_slots,
+            )
+
+        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
         return SpeculativeScores(
             probs=all_probs,
             token_ids=all_tokens,
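For intuition on what is being contracted here: batch expansion gives each of a sequence's k proposal tokens its own target-model row, plus one extra row that yields the bonus token, so a fully speculative batch expands to batch_size * (k + 1) rows. A rough illustration of the bookkeeping, with hypothetical sizes only:

    # Hypothetical sizes, for illustration.
    bs, k = 2, 3                    # two sequences, three proposal tokens each
    expanded_rows = bs * (k + 1)    # 8 rows scored by the target model
    # When every sequence was expanded the same way (no non-spec groups),
    # _contract_batch_all_spec can simply reshape [8, ...] back to [2, 4, ...].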
@@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        spec_seqs, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
+            split_batch_by_proposal_len(
+                seq_group_metadata_list, proposal_lens_list)

         target_seq_group_metadata_list = self._create_scoring_model_input(
             seq_group_metadata_list=spec_seqs,

@@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # The number of tokens in the expanded batch used for speculation is
         # equal to the total expanded batch size minus the number of samples for
         # non-speculative sequences.
-        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        non_spec_expanded_bs = len(non_spec_target_token_ids)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)

@@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

         if target_hidden_states is not None:
             target_hidden_states = target_hidden_states.reshape(
-                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+                *target_token_ids.shape, target_hidden_states.shape[-1])

         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)

@@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             all_hidden_states = None

         if non_spec_indices:
-            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
-            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
-
+            all_tokens[non_spec_indices, :1] = \
+                non_spec_target_token_ids.unsqueeze(1)
+            all_probs[non_spec_indices, :1, :] = \
+                non_spec_target_probs.unsqueeze(1)
+            all_logprobs[non_spec_indices, :1, :] = \
+                non_spec_target_logprobs.unsqueeze(1)
             if all_hidden_states is not None:
-                all_hidden_states[
-                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+                assert non_spec_target_hidden_states is not None
+                all_hidden_states[non_spec_indices, :1, :] = \
+                    non_spec_target_hidden_states.unsqueeze(1)

         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs
-
             if all_hidden_states is not None:
                 all_hidden_states[spec_indices] = target_hidden_states

         return all_tokens, all_probs, all_logprobs, all_hidden_states

+    def _contract_batch_all_spec(
+        self,
+        target_sampler_output: SamplerOutput,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+
+        It assumes all sequences in the batch were previously expanded.
+        """
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+
+        return (target_token_ids, target_probs, target_logprobs,
+                target_hidden_states)
+
     def _create_scoring_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
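A small sketch of the reshape that _contract_batch_all_spec relies on, using made-up tensors where only the shapes matter:

    import torch

    bs, k, vocab = 2, 3, 11                         # hypothetical sizes
    flat_ids = torch.arange(bs * (k + 1))           # [8], one row per scored position
    ids = flat_ids.reshape(bs, k + 1)               # [2, 4]
    flat_probs = torch.rand(bs * (k + 1), vocab)    # [8, 11]
    probs = flat_probs.reshape(*ids.shape, vocab)   # [2, 4, 11]
    assert torch.equal(probs[1, 0], flat_probs[4])  # row 4 is sequence 1, position 0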
@@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             token_chunk_size=1,
         )

+    @staticmethod
     def _split_scoring_output(
-        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+        sampler_output: SamplerOutput, num_scoring_tokens: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                Optional[torch.Tensor], torch.Tensor, torch.Tensor,
                torch.Tensor, Optional[torch.Tensor]]:

@@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         #
         # First samples are from speculative scoring, latter samples are non-
         # speculative samples.
-        split_sizes = [
-            num_scoring_tokens,
-            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
-        ]
+        split_sizes = (num_scoring_tokens,
+                       sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens)
         (spec_probs, non_spec_probs
          ) = sampler_output.sampled_token_probs.split(split_sizes)
         (spec_sampled_tokens, non_spec_sampled_tokens

@@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         else:
             spec_hidden_states, non_spec_hidden_states = None, None

-        # Convert scores to tensors.
-        sampler_output.sampled_token_probs = spec_probs
-        sampler_output.sampled_token_ids = spec_sampled_tokens
-        sampler_output.logprobs = spec_logprobs
-        sampler_output.hidden_states = spec_hidden_states
-        (target_token_ids, target_probs, target_logprobs,
-         target_hidden_states) = sampler_output_to_torch([sampler_output],
-                                                         True)
-
-        # Convert non-speculative output tokens to tensors.
-        sampler_output.sampled_token_probs = non_spec_probs
-        sampler_output.sampled_token_ids = non_spec_sampled_tokens
-        sampler_output.logprobs = non_spec_logprobs
-        sampler_output.hidden_states = non_spec_hidden_states
-        (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs,
-         non_spec_target_hidden_states) = sampler_output_to_torch(
-             [sampler_output], True)
-
-        return (target_token_ids, target_probs, target_logprobs,
-                target_hidden_states, non_spec_target_token_ids,
-                non_spec_target_probs, non_spec_target_logprobs,
-                non_spec_target_hidden_states)
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)

+    @staticmethod
     def _create_target_seq_id_iterator(
-        self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
         """Create an iterator for creating target sequence ids.
         Target sequence ids are distinct from sequence ids because we create a
         distinct target sequence id for each proposal token to be scored.
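The simplified _split_scoring_output now returns the split tensors directly instead of round-tripping them through sampler_output_to_torch. A toy sketch of the underlying torch split with per-chunk sizes (illustrative values only):

    import torch

    num_scoring_tokens = 6                 # hypothetical
    sampled = torch.arange(8)              # 6 speculative rows + 2 non-spec rows
    split_sizes = (num_scoring_tokens, sampled.numel() - num_scoring_tokens)
    spec, non_spec = sampled.split(split_sizes)
    assert spec.tolist() == [0, 1, 2, 3, 4, 5]
    assert non_spec.tolist() == [6, 7]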
@@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         """
         return count(start=max(seq_ids) + 1)

+    @staticmethod
     def _get_token_ids_to_score(
-        self,
         full_spec_token_ids: List[TokenId]  # shape: [k]
     ) -> List[List[TokenId]]:
         """Given an int tensor of proposal token ids, return a list of

@@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         empty_token_ids: List[TokenId] = []

         token_ids_to_score = [empty_token_ids]
-        token_ids_to_score.extend([
-            full_spec_token_ids[:i + 1]
-            for i in range(len(full_spec_token_ids))
-        ])
+        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
+                                  for i in range(len(full_spec_token_ids)))
         return token_ids_to_score
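The behavior of _get_token_ids_to_score is unchanged; only the list comprehension becomes a generator. A standalone rehearsal of its logic with made-up proposal token ids:

    # Illustration only: one prefix per scored position, starting from the
    # empty prefix (no proposal tokens appended).
    full_spec_token_ids = [10, 20, 30]
    token_ids_to_score = [[]]
    token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                              for i in range(len(full_spec_token_ids)))
    assert token_ids_to_score == [[], [10], [10, 20], [10, 20, 30]]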
@@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # used during the prefill phase.
         # 2. Auto-disable enabled: The running queue size exceeds
         #    the specified threshold.
-        # 3. No request: There are no requests in the batch.
+        # 3. No request: There are no requests in the batch, or
+        #    none of the requests in the batch have spec decoding enabled.
         # In any of these cases, the proposer and scorer workers
         # are called normally.
-        no_spec = num_lookahead_slots == 0 or len(
-            execute_model_req.seq_group_metadata_list
-        ) == 0 or disable_all_speculation
+        no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
+            sgm.num_speculative_tokens == 0
+            for sgm in execute_model_req.seq_group_metadata_list)

         # Broadcast how many lookahead slots are scheduled for this step, and
         # whether all speculation is disabled, to all non-driver workers.

@@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
         # to stop trading off throughput for latency.
-        disable_all_speculation = (execute_model_req.running_queue_size >=
-                                   self.disable_by_batch_size)
-
-        return disable_all_speculation
+        return (execute_model_req.running_queue_size >=
+                self.disable_by_batch_size)

     def _maybe_disable_speculative_tokens(
             self, disable_all_speculation: bool,

@@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
-        _, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        _, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
+            seq_group_metadata_list, proposal_lens_list)
         original_indices = spec_indices + non_spec_indices

         # Get probabilities of target model, excluding bonus token.
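One consequence of the new no_spec check, sketched with a hypothetical batch: if every sequence group has speculation disabled (num_speculative_tokens == 0), the worker now takes the non-speculative path even though the batch is non-empty; an empty batch is still covered because all() over an empty iterable is True.

    # Illustration only; SimpleNamespace stands in for SequenceGroupMetadata.
    from types import SimpleNamespace

    seq_group_metadata_list = [SimpleNamespace(num_speculative_tokens=0),
                               SimpleNamespace(num_speculative_tokens=0)]
    disable_all_speculation = False
    num_lookahead_slots = 4
    no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
        sgm.num_speculative_tokens == 0 for sgm in seq_group_metadata_list)
    assert no_spec  # previously this required an empty batch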
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):

         # Currently only proposal lens of 0 or the global batch proposal len
         # are supported.
-        # If max_proposal_len is defined, then we shall no exceed this
+        # If max_proposal_len is defined, then we shall not exceed this
         # quota for nonzero_proposal
         new_k = 0
         if (self.max_proposal_len is None
@@ -1,6 +1,6 @@
 import time
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple

 import torch

@@ -98,33 +98,26 @@ def create_sequence_group_output(

 def split_batch_by_proposal_len(
     seq_group_metadata_list: List[SequenceGroupMetadata],
-    proposal_lens: List[int], select_proposal_len_zero: bool
-) -> Tuple[List[SequenceGroupMetadata], List[int]]:
+    proposal_lens: List[int],
+) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
+        List[SequenceGroupMetadata], List[int]]]:
     """Utility function that splits a batch based on whether the proposal len is
     zero or not. We should remove this once vLLM supports per-sequence proposal
     lens in a batch.
     """
-
-    if select_proposal_len_zero:
-        predicate = lambda proposal_len: proposal_len == 0
-    else:
-        predicate = lambda proposal_len: proposal_len != 0
-
-    indices = [
-        i for i, (_, proposal_len
-                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
-        if predicate(proposal_len)
-    ]
-    seq_groups = [
-        seq_group for seq_group, proposal_len in zip(
-            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
-    ]
-
-    return seq_groups, indices
+    nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
+    zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
+    for i, (seq_group, proposal_len) in enumerate(
+            zip(seq_group_metadata_list, proposal_lens)):
+        seq_groups, indices = nonzero_lists if proposal_len else zero_lists
+        seq_groups.append(seq_group)
+        indices.append(i)
+    return nonzero_lists, zero_lists


 def sampler_output_to_torch(
-    sampler_output_list: List[SamplerOutput], sampler_transposed: bool
+    sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Utility function which converts a list of SamplerOutput to tensors.
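The new implementation walks the batch once, appending each group to either the nonzero or the zero bucket, then returns both buckets. A minimal usage sketch with made-up proposal lengths (the strings stand in for real SequenceGroupMetadata objects):

    seq_group_metadata_list = ["sg0", "sg1", "sg2"]   # placeholders for metadata
    proposal_lens = [0, 5, 5]
    (spec_groups, spec_indices), (non_spec_groups, non_spec_indices) = \
        split_batch_by_proposal_len(seq_group_metadata_list, proposal_lens)
    assert (spec_groups, spec_indices) == (["sg1", "sg2"], [1, 2])
    assert (non_spec_groups, non_spec_indices) == (["sg0"], [0])

One call now yields both partitions, so callers no longer pass select_proposal_len_zero or scan the batch twice.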
@@ -148,18 +141,12 @@ def sampler_output_to_torch(
         dim=0,
     )

-    if sampler_transposed:
-        sampled_token_probs = sampled_token_probs.transpose(0, 1)
-
     # shape: [batch_size, num_sampler_output, vocab_size]
     sampled_token_logprobs = torch.stack(
         [sampler_output.logprobs for sampler_output in sampler_output_list],
         dim=0,
     )

-    if sampler_transposed:
-        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
-
     # shape: [batch_size, num_sampler_output]
     sampled_token_ids = torch.stack(
         [

@@ -168,7 +155,10 @@ def sampler_output_to_torch(
         ],
         dim=0,
     )
+
     if sampler_transposed:
+        sampled_token_probs = sampled_token_probs.transpose(0, 1)
+        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
         sampled_token_ids = sampled_token_ids.transpose(0, 1)

     if sampler_output_list[0].hidden_states is not None:
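Consolidating the per-tensor transposes into one block after all the stacks is behavior-preserving, since nothing between a stack and its former transpose reads the tensor. A rough shape sketch with made-up sizes, treating axis names as illustrative rather than taken from this file:

    import torch

    steps, bs, vocab = 4, 2, 11                      # hypothetical sizes
    per_output = [torch.rand(bs, vocab) for _ in range(steps)]
    stacked = torch.stack(per_output, dim=0)         # [steps, bs, vocab]
    assert stacked.transpose(0, 1).shape == (bs, steps, vocab)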