diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py
index 18dbdd5bc9..06780d4b8c 100644
--- a/tests/spec_decode/test_utils.py
+++ b/tests/spec_decode/test_utils.py
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
 
 def test_filter_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     expected_groups = [
         fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
 
 def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 2]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     expected_groups = [
         fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
 
 
 def test_empty_inputs():
-    filtered_groups, indices = split_batch_by_proposal_len(
-        [], [], select_proposal_len_zero=True)
+    _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
 
     assert filtered_groups == []
     assert indices == []
@@ -95,10 +92,9 @@ def test_empty_inputs():
 
 def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [0, 0, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     assert filtered_groups == []
     assert indices == []
@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
 
 def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [1, 1, 1]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     assert filtered_groups == []
     assert indices == []
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index ad6f3f3138..8a691d65aa 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
                            get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
-                                   split_batch_by_proposal_len)
+from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
 from vllm.worker.worker_base import WorkerBase
 
 SeqId = int
@@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
-        (all_tokens, all_probs, spec_logprobs,
-         all_hidden_states) = self._contract_batch(
-             contracted_bs=len(execute_model_req.seq_group_metadata_list),
-             target_sampler_output=target_sampler_output,
-             proposals=proposals,
-             num_scoring_tokens=num_scoring_tokens,
-             non_spec_indices=non_spec_indices,
-             spec_indices=spec_indices,
-             k=execute_model_req.num_lookahead_slots,
-         )
+        if not non_spec_indices:
+            # All sequence groups in batch have spec decoding enabled
+            contracted = self._contract_batch_all_spec(
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+            )
+        else:
+            # Batch has a mix of spec decode enabled and disabled seq groups
+            contracted = self._contract_batch(
+                contracted_bs=len(execute_model_req.seq_group_metadata_list),
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+                num_scoring_tokens=num_scoring_tokens,
+                non_spec_indices=non_spec_indices,
+                spec_indices=spec_indices,
+                k=execute_model_req.num_lookahead_slots,
+            )
 
+        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
         return SpeculativeScores(
             probs=all_probs,
             token_ids=all_tokens,
@@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        spec_seqs, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
+            split_batch_by_proposal_len(
+                seq_group_metadata_list, proposal_lens_list)
 
         target_seq_group_metadata_list = self._create_scoring_model_input(
             seq_group_metadata_list=spec_seqs,
@@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         # The number of tokens in the expanded batch used for speculation is
         # equal to the total expanded batch size minus the number of samples for
         # non-speculative sequences.
-        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        non_spec_expanded_bs = len(non_spec_target_token_ids)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
 
         if target_hidden_states is not None:
             target_hidden_states = target_hidden_states.reshape(
-                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+                *target_token_ids.shape, target_hidden_states.shape[-1])
 
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
@@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             all_hidden_states = None
 
         if non_spec_indices:
-            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
-            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
-
+            all_tokens[non_spec_indices, :1] = \
+                non_spec_target_token_ids.unsqueeze(1)
+            all_probs[non_spec_indices, :1, :] = \
+                non_spec_target_probs.unsqueeze(1)
+            all_logprobs[non_spec_indices, :1, :] = \
+                non_spec_target_logprobs.unsqueeze(1)
             if all_hidden_states is not None:
-                all_hidden_states[
-                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+                assert non_spec_target_hidden_states is not None
+                all_hidden_states[non_spec_indices, :1, :] = \
+                    non_spec_target_hidden_states.unsqueeze(1)
 
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs
-
             if all_hidden_states is not None:
                 all_hidden_states[spec_indices] = target_hidden_states
 
         return all_tokens, all_probs, all_logprobs, all_hidden_states
 
+    def _contract_batch_all_spec(
+            self,
+            target_sampler_output: SamplerOutput,
+            proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+
+        It assumes all sequences in the batch were previously expanded.
+        """
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * (k + 1)] back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+
+        return (target_token_ids, target_probs, target_logprobs,
+                target_hidden_states)
+
     def _create_scoring_model_input(
             self,
             seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 token_chunk_size=1,
             )
 
+    @staticmethod
     def _split_scoring_output(
-        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+        sampler_output: SamplerOutput, num_scoring_tokens: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
@@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         #
         # First samples are from speculative scoring, latter samples are non-
         # speculative samples.
-        split_sizes = [
-            num_scoring_tokens,
-            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
-        ]
+        split_sizes = (num_scoring_tokens,
+                       sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens)
         (spec_probs, non_spec_probs
          ) = sampler_output.sampled_token_probs.split(split_sizes)
         (spec_sampled_tokens, non_spec_sampled_tokens
          ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
@@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         else:
             spec_hidden_states, non_spec_hidden_states = None, None
 
-        # Convert scores to tensors.
-        sampler_output.sampled_token_probs = spec_probs
-        sampler_output.sampled_token_ids = spec_sampled_tokens
-        sampler_output.logprobs = spec_logprobs
-        sampler_output.hidden_states = spec_hidden_states
-        (target_token_ids, target_probs, target_logprobs,
-         target_hidden_states) = sampler_output_to_torch([sampler_output],
-                                                         True)
-
-        # Convert non-speculative output tokens to tensors.
-        sampler_output.sampled_token_probs = non_spec_probs
-        sampler_output.sampled_token_ids = non_spec_sampled_tokens
-        sampler_output.logprobs = non_spec_logprobs
-        sampler_output.hidden_states = non_spec_hidden_states
-        (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs,
-         non_spec_target_hidden_states) = sampler_output_to_torch(
-             [sampler_output], True)
-
-        return (target_token_ids, target_probs, target_logprobs,
-                target_hidden_states, non_spec_target_token_ids,
-                non_spec_target_probs, non_spec_target_logprobs,
-                non_spec_target_hidden_states)
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)
 
+    @staticmethod
     def _create_target_seq_id_iterator(
-            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
         """Create an iterator for creating target sequence ids.
         Target sequence ids are distinct from sequence ids because we create a
         distinct target sequence id for each proposal token to be scored.
@@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         """
         return count(start=max(seq_ids) + 1)
 
+    @staticmethod
     def _get_token_ids_to_score(
-        self,
         full_spec_token_ids: List[TokenId]  # shape: [k]
     ) -> List[List[TokenId]]:
         """Given an int tensor of proposal token ids, return a list of
@@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         empty_token_ids: List[TokenId] = []
 
         token_ids_to_score = [empty_token_ids]
-        token_ids_to_score.extend([
-            full_spec_token_ids[:i + 1]
-            for i in range(len(full_spec_token_ids))
-        ])
+        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
+                                  for i in range(len(full_spec_token_ids)))
         return token_ids_to_score
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 2762b83880..9b1f21fcb4 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         #    used during the prefill phase.
         # 2. Auto-disable enabled: The running queue size exceeds
         #    the specified threshold.
-        # 3. No request: There are no requests in the batch.
+        # 3. No request: There are no requests in the batch, or
+        #    none of the requests in the batch have spec decoding enabled.
         # In any of these cases, the proposer and scorer workers
         # are called normally.
-        no_spec = num_lookahead_slots == 0 or len(
-            execute_model_req.seq_group_metadata_list
-        ) == 0 or disable_all_speculation
+        no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
+            sgm.num_speculative_tokens == 0
+            for sgm in execute_model_req.seq_group_metadata_list)
 
         # Broadcast how many lookahead slots are scheduled for this step, and
         # whether all speculation is disabled, to all non-driver workers.
@@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
         # to stop trading off throughput for latency.
-        disable_all_speculation = (execute_model_req.running_queue_size >=
-                                   self.disable_by_batch_size)
-
-        return disable_all_speculation
+        return (execute_model_req.running_queue_size >=
+                self.disable_by_batch_size)
 
     def _maybe_disable_speculative_tokens(
             self, disable_all_speculation: bool,
@@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        _, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        _, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
+            seq_group_metadata_list, proposal_lens_list)
         original_indices = spec_indices + non_spec_indices
 
         # Get probabilities of target model, excluding bonus token.
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
index 28f7f7eb06..aa993e539b 100644
--- a/vllm/spec_decode/top1_proposer.py
+++ b/vllm/spec_decode/top1_proposer.py
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
 
             # Currently only proposal lens of 0 or the global batch proposal len
             # are supported.
-            # If max_proposal_len is defined, then we shall no exceed this
+            # If max_proposal_len is defined, then we shall not exceed this
             # quota for nonzero_proposal
             new_k = 0
             if (self.max_proposal_len is None
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 9315cd0f75..d18ee47e23 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -1,6 +1,6 @@
 import time
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
 
 import torch
 
@@ -98,33 +98,26 @@ def create_sequence_group_output(
 
 def split_batch_by_proposal_len(
     seq_group_metadata_list: List[SequenceGroupMetadata],
-    proposal_lens: List[int], select_proposal_len_zero: bool
-) -> Tuple[List[SequenceGroupMetadata], List[int]]:
+    proposal_lens: List[int],
+) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
+        List[SequenceGroupMetadata], List[int]]]:
     """Utility function that splits a batch based on whether the proposal len
     is zero or not. We should remove this once vLLM supports per-sequence
     proposal lens in a batch.
     """
-    if select_proposal_len_zero:
-        predicate = lambda proposal_len: proposal_len == 0
-    else:
-        predicate = lambda proposal_len: proposal_len != 0
-
-    indices = [
-        i for i, (_, proposal_len
-                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
-        if predicate(proposal_len)
-    ]
-    seq_groups = [
-        seq_group for seq_group, proposal_len in zip(
-            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
-    ]
-
-    return seq_groups, indices
+    nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
+    zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
+    for i, (seq_group, proposal_len) in enumerate(
+            zip(seq_group_metadata_list, proposal_lens)):
+        seq_groups, indices = nonzero_lists if proposal_len else zero_lists
+        seq_groups.append(seq_group)
+        indices.append(i)
+    return nonzero_lists, zero_lists
 
 
 def sampler_output_to_torch(
-    sampler_output_list: List[SamplerOutput], sampler_transposed: bool
+    sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Utility function which converts a list of SamplerOutput to tensors.
 
@@ -148,18 +141,12 @@ def sampler_output_to_torch(
         dim=0,
     )
 
-    if sampler_transposed:
-        sampled_token_probs = sampled_token_probs.transpose(0, 1)
-
     # shape: [batch_size, num_sampler_output, vocab_size]
     sampled_token_logprobs = torch.stack(
         [sampler_output.logprobs for sampler_output in sampler_output_list],
         dim=0,
     )
 
-    if sampler_transposed:
-        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
-
     # shape: [batch_size, num_sampler_output]
     sampled_token_ids = torch.stack(
         [
             sampler_output.sampled_token_ids.flatten()
             for sampler_output in sampler_output_list
         ],
         dim=0,
     )
+
     if sampler_transposed:
+        sampled_token_probs = sampled_token_probs.transpose(0, 1)
+        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
         sampled_token_ids = sampled_token_ids.transpose(0, 1)
 
     if sampler_output_list[0].hidden_states is not None:
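
Reviewer note (not part of the patch): a minimal sketch of the new split_batch_by_proposal_len calling convention, assuming the patched vllm.spec_decode.util is importable. The SimpleNamespace objects below are hypothetical stand-ins for real SequenceGroupMetadata; the helper only iterates and appends, so duck typing is enough for illustration.

# Illustrative sketch only -- not part of the diff above.
from types import SimpleNamespace

from vllm.spec_decode.util import split_batch_by_proposal_len

seq_group_metadata_list = [SimpleNamespace(request_id=str(i)) for i in range(3)]
proposal_lens = [0, 5, 0]

# One call now returns both partitions, each paired with its original batch
# indices, replacing the two select_proposal_len_zero=... calls it supersedes.
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
    split_batch_by_proposal_len(seq_group_metadata_list, proposal_lens)

assert spec_indices == [1]         # nonzero proposal len -> speculative split
assert non_spec_indices == [0, 2]  # zero proposal len -> non-speculative split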
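
Reviewer note (not part of the patch): when every sequence group speculates, the expanded batch is exactly batch_size * (k + 1) single-token entries laid out contiguously, so the new _contract_batch_all_spec path reduces to a reshape. A standalone torch sketch of that shape manipulation (sizes are made up):

# Illustrative sketch only -- shapes are arbitrary.
import torch

batch_size, k, vocab_size = 2, 3, 8

# Expanded batch: batch_size * (k + 1) scored entries, flattened.
flat_token_ids = torch.arange(batch_size * (k + 1))
flat_probs = torch.rand(batch_size * (k + 1), vocab_size)

# Contraction is a view back to the per-sequence layout, mirroring the
# reshape calls added in _contract_batch_all_spec.
token_ids = flat_token_ids.reshape(batch_size, k + 1)
probs = flat_probs.reshape(*token_ids.shape, vocab_size)

assert token_ids.shape == (batch_size, k + 1)
assert probs.shape == (batch_size, k + 1, vocab_size)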
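
Reviewer note (not part of the patch): the reworked no_spec condition in spec_decode_worker.py also covers the empty-batch case, since all() over an empty iterable is True. A hedged sketch with hypothetical stand-in objects:

# Illustrative sketch only -- SimpleNamespace stands in for SequenceGroupMetadata.
from types import SimpleNamespace


def no_spec(num_lookahead_slots, disable_all_speculation,
            seq_group_metadata_list):
    # Mirrors the condition added in SpecDecodeWorker: skip speculation when
    # there are no lookahead slots, when it is auto-disabled, or when no
    # request in the batch has speculative tokens scheduled.
    return (num_lookahead_slots == 0 or disable_all_speculation
            or all(sgm.num_speculative_tokens == 0
                   for sgm in seq_group_metadata_list))


batch = [SimpleNamespace(num_speculative_tokens=0) for _ in range(2)]
assert no_spec(4, False, batch)  # every request opted out of spec decode
assert no_spec(4, False, [])     # empty batch: all(...) over [] is True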