[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)

This commit is contained in:
Nick Hill 2024-08-25 15:45:14 -07:00 committed by GitHub
parent 70c094ade6
commit 1856aff4d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 124 deletions

View File

@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
def test_filter_zero_length_proposals(fake_sequence_group_metadata): def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0] proposal_lens = [0, 1, 0]
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups,
fake_sequence_group_metadata, indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=True)
expected_groups = [ expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2] fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2] proposal_lens = [0, 1, 2]
filtered_groups, indices = split_batch_by_proposal_len( (filtered_groups,
fake_sequence_group_metadata, indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=False)
expected_groups = [ expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2] fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def test_empty_inputs(): def test_empty_inputs():
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
[], [], select_proposal_len_zero=True)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
@ -95,10 +92,9 @@ def test_empty_inputs():
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0] proposal_lens = [0, 0, 0]
filtered_groups, indices = split_batch_by_proposal_len( (filtered_groups,
fake_sequence_group_metadata, indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=False)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []
@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1] proposal_lens = [1, 1, 1]
filtered_groups, indices = split_batch_by_proposal_len( _, (filtered_groups,
fake_sequence_group_metadata, indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens, proposal_lens)
select_proposal_len_zero=True)
assert filtered_groups == [] assert filtered_groups == []
assert indices == [] assert indices == []

View File

@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
get_all_seq_ids) get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
SeqId = int SeqId = int
@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output" assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
(all_tokens, all_probs, spec_logprobs, if not non_spec_indices:
all_hidden_states) = self._contract_batch( # All sequence groups in batch have spec decoding enabled
contracted_bs=len(execute_model_req.seq_group_metadata_list), contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
num_scoring_tokens=num_scoring_tokens, )
non_spec_indices=non_spec_indices, else:
spec_indices=spec_indices, # Batch has a mix of spec decode enabled and disabled seq groups
k=execute_model_req.num_lookahead_slots, contracted = self._contract_batch(
) contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
return SpeculativeScores( return SpeculativeScores(
probs=all_probs, probs=all_probs,
token_ids=all_tokens, token_ids=all_tokens,
@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec # proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be # and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens. # done by supporting per-sequence proposal lens.
spec_seqs, spec_indices = split_batch_by_proposal_len( (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
seq_group_metadata_list, split_batch_by_proposal_len(
proposal_lens_list, seq_group_metadata_list, proposal_lens_list)
select_proposal_len_zero=False)
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input( target_seq_group_metadata_list = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs, seq_group_metadata_list=spec_seqs,
@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is # The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for # equal to the total expanded batch size minus the number of samples for
# non-speculative sequences. # non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape non_spec_expanded_bs = len(non_spec_target_token_ids)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if target_hidden_states is not None: if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape( target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) *target_token_ids.shape, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1) fill_value=-1)
@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states = None all_hidden_states = None
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = \
all_probs[non_spec_indices, :1, :] = non_spec_target_probs non_spec_target_token_ids.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None: if all_hidden_states is not None:
all_hidden_states[ assert non_spec_target_hidden_states is not None
non_spec_indices, :1, :] = non_spec_target_hidden_states all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
if spec_indices: if spec_indices:
all_tokens[spec_indices] = target_token_ids all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None: if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states return all_tokens, all_probs, all_logprobs, all_hidden_states
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states)
def _create_scoring_model_input( def _create_scoring_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size=1, token_chunk_size=1,
) )
@staticmethod
def _split_scoring_output( def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]: torch.Tensor, Optional[torch.Tensor]]:
@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# #
# First samples are from speculative scoring, latter samples are non- # First samples are from speculative scoring, latter samples are non-
# speculative samples. # speculative samples.
split_sizes = [ split_sizes = (num_scoring_tokens,
num_scoring_tokens, sampler_output.sampled_token_ids.numel() -
sampler_output.sampled_token_ids.numel() - num_scoring_tokens num_scoring_tokens)
]
(spec_probs, non_spec_probs (spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes) ) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens (spec_sampled_tokens, non_spec_sampled_tokens
@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else: else:
spec_hidden_states, non_spec_hidden_states = None, None spec_hidden_states, non_spec_hidden_states = None, None
# Convert scores to tensors. return (spec_sampled_tokens, spec_probs, spec_logprobs,
sampler_output.sampled_token_probs = spec_probs spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
sampler_output.sampled_token_ids = spec_sampled_tokens non_spec_logprobs, non_spec_hidden_states)
sampler_output.logprobs = spec_logprobs
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)
@staticmethod
def _create_target_seq_id_iterator( def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids. """Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored. distinct target sequence id for each proposal token to be scored.
@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
""" """
return count(start=max(seq_ids) + 1) return count(start=max(seq_ids) + 1)
@staticmethod
def _get_token_ids_to_score( def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k] full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]: ) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of """Given an int tensor of proposal token ids, return a list of
@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids: List[TokenId] = [] empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids] token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([ token_ids_to_score.extend(full_spec_token_ids[:i + 1]
full_spec_token_ids[:i + 1] for i in range(len(full_spec_token_ids)))
for i in range(len(full_spec_token_ids))
])
return token_ids_to_score return token_ids_to_score

View File

@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# used during the prefill phase. # used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds # 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold. # the specified threshold.
# 3. No request: There are no requests in the batch. # 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers # In any of these cases, the proposer and scorer workers
# are called normally. # are called normally.
no_spec = num_lookahead_slots == 0 or len( no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
execute_model_req.seq_group_metadata_list sgm.num_speculative_tokens == 0
) == 0 or disable_all_speculation for sgm in execute_model_req.seq_group_metadata_list)
# Broadcast how many lookahead slots are scheduled for this step, and # Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers. # whether all speculation is disabled, to all non-driver workers.
@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self, execute_model_req: ExecuteModelRequest) -> bool: self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding # When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency. # to stop trading off throughput for latency.
disable_all_speculation = (execute_model_req.running_queue_size >= return (execute_model_req.running_queue_size >=
self.disable_by_batch_size) self.disable_by_batch_size)
return disable_all_speculation
def _maybe_disable_speculative_tokens( def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool, self, disable_all_speculation: bool,
@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# proposal len. This adds some complexity (splitting the batch into spec # proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be # and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens. # done by supporting per-sequence proposal lens.
_, spec_indices = split_batch_by_proposal_len( (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
seq_group_metadata_list, seq_group_metadata_list, proposal_lens_list)
proposal_lens_list,
select_proposal_len_zero=False)
_, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, excluding bonus token. # Get probabilities of target model, excluding bonus token.

View File

@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len # Currently only proposal lens of 0 or the global batch proposal len
# are supported. # are supported.
# If max_proposal_len is defined, then we shall no exceed this # If max_proposal_len is defined, then we shall not exceed this
# quota for nonzero_proposal # quota for nonzero_proposal
new_k = 0 new_k = 0
if (self.max_proposal_len is None if (self.max_proposal_len is None

View File

@ -1,6 +1,6 @@
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Sequence, Tuple
import torch import torch
@ -98,33 +98,26 @@ def create_sequence_group_output(
def split_batch_by_proposal_len( def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool proposal_lens: List[int],
) -> Tuple[List[SequenceGroupMetadata], List[int]]: ) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
List[SequenceGroupMetadata], List[int]]]:
"""Utility function that splits a batch based on whether the proposal len is """Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch. lens in a batch.
""" """
if select_proposal_len_zero: nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
predicate = lambda proposal_len: proposal_len == 0 zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
else: for i, (seq_group, proposal_len) in enumerate(
predicate = lambda proposal_len: proposal_len != 0 zip(seq_group_metadata_list, proposal_lens)):
seq_groups, indices = nonzero_lists if proposal_len else zero_lists
indices = [ seq_groups.append(seq_group)
i for i, (_, proposal_len indices.append(i)
) in enumerate(zip(seq_group_metadata_list, proposal_lens)) return nonzero_lists, zero_lists
if predicate(proposal_len)
]
seq_groups = [
seq_group for seq_group, proposal_len in zip(
seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
]
return seq_groups, indices
def sampler_output_to_torch( def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput], sampler_transposed: bool sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Utility function which converts a list of SamplerOutput to tensors. """Utility function which converts a list of SamplerOutput to tensors.
@ -148,18 +141,12 @@ def sampler_output_to_torch(
dim=0, dim=0,
) )
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size] # shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack( sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list], [sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0, dim=0,
) )
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output] # shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack( sampled_token_ids = torch.stack(
[ [
@ -168,7 +155,10 @@ def sampler_output_to_torch(
], ],
dim=0, dim=0,
) )
if sampler_transposed: if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
sampled_token_ids = sampled_token_ids.transpose(0, 1) sampled_token_ids = sampled_token_ids.transpose(0, 1)
if sampler_output_list[0].hidden_states is not None: if sampler_output_list[0].hidden_states is not None: