[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)

This commit is contained in:
Nick Hill 2024-08-25 15:45:14 -07:00 committed by GitHub
parent 70c094ade6
commit 1856aff4d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 124 deletions

View File

@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def test_empty_inputs():
filtered_groups, indices = split_batch_by_proposal_len(
[], [], select_proposal_len_zero=True)
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
@ -95,10 +92,9 @@ def test_empty_inputs():
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=False)
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
filtered_groups, indices = split_batch_by_proposal_len(
fake_sequence_group_metadata,
proposal_lens,
select_proposal_len_zero=True)
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []

View File

@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase
SeqId = int
@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
(all_tokens, all_probs, spec_logprobs,
all_hidden_states) = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
contracted = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
spec_seqs, spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=False)
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
target_seq_group_metadata_list = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs,
@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
non_spec_expanded_bs = len(non_spec_target_token_ids)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
*target_token_ids.shape, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states = None
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None:
all_hidden_states[
non_spec_indices, :1, :] = non_spec_target_hidden_states
assert non_spec_target_hidden_states is not None
all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states)
def _create_scoring_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size=1,
)
@staticmethod
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes = [
num_scoring_tokens,
sampler_output.sampled_token_ids.numel() - num_scoring_tokens
]
split_sizes = (num_scoring_tokens,
sampler_output.sampled_token_ids.numel() -
num_scoring_tokens)
(spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else:
spec_hidden_states, non_spec_hidden_states = None, None
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
sampler_output.logprobs = spec_logprobs
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)
return (spec_sampled_tokens, spec_probs, spec_logprobs,
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
non_spec_logprobs, non_spec_hidden_states)
@staticmethod
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""
return count(start=max(seq_ids) + 1)
@staticmethod
def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of
@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids))
])
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids)))
return token_ids_to_score

View File

@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
no_spec = num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list
) == 0 or disable_all_speculation
no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
sgm.num_speculative_tokens == 0
for sgm in execute_model_req.seq_group_metadata_list)
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all_speculation = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
return disable_all_speculation
return (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
_, spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=False)
_, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
(_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, excluding bonus token.

View File

@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# If max_proposal_len is defined, then we shall not exceed this
# quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None

View File

@ -1,6 +1,6 @@
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple
import torch
@ -98,33 +98,26 @@ def create_sequence_group_output(
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
proposal_lens: List[int],
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
List[SequenceGroupMetadata], List[int]]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
if select_proposal_len_zero:
predicate = lambda proposal_len: proposal_len == 0
else:
predicate = lambda proposal_len: proposal_len != 0
indices = [
i for i, (_, proposal_len
) in enumerate(zip(seq_group_metadata_list, proposal_lens))
if predicate(proposal_len)
]
seq_groups = [
seq_group for seq_group, proposal_len in zip(
seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
]
return seq_groups, indices
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
for i, (seq_group, proposal_len) in enumerate(
zip(seq_group_metadata_list, proposal_lens)):
seq_groups, indices = nonzero_lists if proposal_len else zero_lists
seq_groups.append(seq_group)
indices.append(i)
return nonzero_lists, zero_lists
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors.
@ -148,18 +141,12 @@ def sampler_output_to_torch(
dim=0,
)
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
[
@ -168,7 +155,10 @@ def sampler_output_to_torch(
],
dim=0,
)
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
sampled_token_ids = sampled_token_ids.transpose(0, 1)
if sampler_output_list[0].hidden_states is not None: