mirror of https://github.com/vllm-project/vllm
[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)
This commit is contained in:
parent
70c094ade6
commit
1856aff4d6
|
@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
|
||||||
|
|
||||||
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||||
proposal_lens = [0, 1, 0]
|
proposal_lens = [0, 1, 0]
|
||||||
filtered_groups, indices = split_batch_by_proposal_len(
|
_, (filtered_groups,
|
||||||
fake_sequence_group_metadata,
|
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||||
proposal_lens,
|
proposal_lens)
|
||||||
select_proposal_len_zero=True)
|
|
||||||
|
|
||||||
expected_groups = [
|
expected_groups = [
|
||||||
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
||||||
|
@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||||
|
|
||||||
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||||
proposal_lens = [0, 1, 2]
|
proposal_lens = [0, 1, 2]
|
||||||
filtered_groups, indices = split_batch_by_proposal_len(
|
(filtered_groups,
|
||||||
fake_sequence_group_metadata,
|
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||||
proposal_lens,
|
proposal_lens)
|
||||||
select_proposal_len_zero=False)
|
|
||||||
|
|
||||||
expected_groups = [
|
expected_groups = [
|
||||||
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
||||||
|
@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||||
|
|
||||||
|
|
||||||
def test_empty_inputs():
|
def test_empty_inputs():
|
||||||
filtered_groups, indices = split_batch_by_proposal_len(
|
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
|
||||||
[], [], select_proposal_len_zero=True)
|
|
||||||
|
|
||||||
assert filtered_groups == []
|
assert filtered_groups == []
|
||||||
assert indices == []
|
assert indices == []
|
||||||
|
@ -95,10 +92,9 @@ def test_empty_inputs():
|
||||||
|
|
||||||
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||||
proposal_lens = [0, 0, 0]
|
proposal_lens = [0, 0, 0]
|
||||||
filtered_groups, indices = split_batch_by_proposal_len(
|
(filtered_groups,
|
||||||
fake_sequence_group_metadata,
|
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||||
proposal_lens,
|
proposal_lens)
|
||||||
select_proposal_len_zero=False)
|
|
||||||
|
|
||||||
assert filtered_groups == []
|
assert filtered_groups == []
|
||||||
assert indices == []
|
assert indices == []
|
||||||
|
@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||||
|
|
||||||
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||||
proposal_lens = [1, 1, 1]
|
proposal_lens = [1, 1, 1]
|
||||||
filtered_groups, indices = split_batch_by_proposal_len(
|
_, (filtered_groups,
|
||||||
fake_sequence_group_metadata,
|
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||||
proposal_lens,
|
proposal_lens)
|
||||||
select_proposal_len_zero=True)
|
|
||||||
|
|
||||||
assert filtered_groups == []
|
assert filtered_groups == []
|
||||||
assert indices == []
|
assert indices == []
|
||||||
|
|
|
@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
|
||||||
get_all_seq_ids)
|
get_all_seq_ids)
|
||||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||||
SpeculativeScorer, SpeculativeScores)
|
SpeculativeScorer, SpeculativeScores)
|
||||||
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
|
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
||||||
split_batch_by_proposal_len)
|
|
||||||
from vllm.worker.worker_base import WorkerBase
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
SeqId = int
|
SeqId = int
|
||||||
|
@ -88,8 +87,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||||
target_sampler_output = target_sampler_output[0]
|
target_sampler_output = target_sampler_output[0]
|
||||||
|
|
||||||
(all_tokens, all_probs, spec_logprobs,
|
if not non_spec_indices:
|
||||||
all_hidden_states) = self._contract_batch(
|
# All sequence groups in batch have spec decoding enabled
|
||||||
|
contracted = self._contract_batch_all_spec(
|
||||||
|
target_sampler_output=target_sampler_output,
|
||||||
|
proposals=proposals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Batch has a mix of spec decode enabled and disabled seq groups
|
||||||
|
contracted = self._contract_batch(
|
||||||
contracted_bs=len(execute_model_req.seq_group_metadata_list),
|
contracted_bs=len(execute_model_req.seq_group_metadata_list),
|
||||||
target_sampler_output=target_sampler_output,
|
target_sampler_output=target_sampler_output,
|
||||||
proposals=proposals,
|
proposals=proposals,
|
||||||
|
@ -99,6 +105,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
k=execute_model_req.num_lookahead_slots,
|
k=execute_model_req.num_lookahead_slots,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
|
||||||
return SpeculativeScores(
|
return SpeculativeScores(
|
||||||
probs=all_probs,
|
probs=all_probs,
|
||||||
token_ids=all_tokens,
|
token_ids=all_tokens,
|
||||||
|
@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
# proposal len. This adds some complexity (splitting the batch into spec
|
# proposal len. This adds some complexity (splitting the batch into spec
|
||||||
# and non spec sequences) and should be removed in the future. It can be
|
# and non spec sequences) and should be removed in the future. It can be
|
||||||
# done by supporting per-sequence proposal lens.
|
# done by supporting per-sequence proposal lens.
|
||||||
spec_seqs, spec_indices = split_batch_by_proposal_len(
|
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
|
||||||
seq_group_metadata_list,
|
split_batch_by_proposal_len(
|
||||||
proposal_lens_list,
|
seq_group_metadata_list, proposal_lens_list)
|
||||||
select_proposal_len_zero=False)
|
|
||||||
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
|
|
||||||
seq_group_metadata_list,
|
|
||||||
proposal_lens_list,
|
|
||||||
select_proposal_len_zero=True)
|
|
||||||
|
|
||||||
target_seq_group_metadata_list = self._create_scoring_model_input(
|
target_seq_group_metadata_list = self._create_scoring_model_input(
|
||||||
seq_group_metadata_list=spec_seqs,
|
seq_group_metadata_list=spec_seqs,
|
||||||
|
@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
# The number of tokens in the expanded batch used for speculation is
|
# The number of tokens in the expanded batch used for speculation is
|
||||||
# equal to the total expanded batch size minus the number of samples for
|
# equal to the total expanded batch size minus the number of samples for
|
||||||
# non-speculative sequences.
|
# non-speculative sequences.
|
||||||
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
|
non_spec_expanded_bs = len(non_spec_target_token_ids)
|
||||||
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
|
||||||
|
|
||||||
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
|
||||||
|
@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
|
|
||||||
if target_hidden_states is not None:
|
if target_hidden_states is not None:
|
||||||
target_hidden_states = target_hidden_states.reshape(
|
target_hidden_states = target_hidden_states.reshape(
|
||||||
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
|
*target_token_ids.shape, target_hidden_states.shape[-1])
|
||||||
|
|
||||||
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
|
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
|
||||||
fill_value=-1)
|
fill_value=-1)
|
||||||
|
@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
all_hidden_states = None
|
all_hidden_states = None
|
||||||
|
|
||||||
if non_spec_indices:
|
if non_spec_indices:
|
||||||
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
|
all_tokens[non_spec_indices, :1] = \
|
||||||
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
|
non_spec_target_token_ids.unsqueeze(1)
|
||||||
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
|
all_probs[non_spec_indices, :1, :] = \
|
||||||
|
non_spec_target_probs.unsqueeze(1)
|
||||||
|
all_logprobs[non_spec_indices, :1, :] = \
|
||||||
|
non_spec_target_logprobs.unsqueeze(1)
|
||||||
if all_hidden_states is not None:
|
if all_hidden_states is not None:
|
||||||
all_hidden_states[
|
assert non_spec_target_hidden_states is not None
|
||||||
non_spec_indices, :1, :] = non_spec_target_hidden_states
|
all_hidden_states[non_spec_indices, :1, :] = \
|
||||||
|
non_spec_target_hidden_states.unsqueeze(1)
|
||||||
|
|
||||||
if spec_indices:
|
if spec_indices:
|
||||||
all_tokens[spec_indices] = target_token_ids
|
all_tokens[spec_indices] = target_token_ids
|
||||||
all_probs[spec_indices] = target_probs
|
all_probs[spec_indices] = target_probs
|
||||||
all_logprobs[spec_indices] = target_logprobs
|
all_logprobs[spec_indices] = target_logprobs
|
||||||
|
|
||||||
if all_hidden_states is not None:
|
if all_hidden_states is not None:
|
||||||
all_hidden_states[spec_indices] = target_hidden_states
|
all_hidden_states[spec_indices] = target_hidden_states
|
||||||
|
|
||||||
return all_tokens, all_probs, all_logprobs, all_hidden_states
|
return all_tokens, all_probs, all_logprobs, all_hidden_states
|
||||||
|
|
||||||
|
def _contract_batch_all_spec(
|
||||||
|
self,
|
||||||
|
target_sampler_output: SamplerOutput,
|
||||||
|
proposals: SpeculativeProposals,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||||
|
Optional[torch.Tensor]]:
|
||||||
|
"""Contract the expanded batch back into its original size.
|
||||||
|
This maps the scores of speculative tokens back to their original
|
||||||
|
sequences.
|
||||||
|
|
||||||
|
It assumes all sequences in the batch were previously expanded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Map distinct sequences used to score each token
|
||||||
|
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
||||||
|
contracted_bs, k = proposals.proposal_token_ids.shape
|
||||||
|
|
||||||
|
# Reshape tensors to original batch size
|
||||||
|
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
|
||||||
|
contracted_bs, k + 1)
|
||||||
|
target_probs = target_sampler_output.sampled_token_probs.reshape(
|
||||||
|
*target_token_ids.shape, self._vocab_size)
|
||||||
|
target_logprobs = target_sampler_output.logprobs.reshape(
|
||||||
|
target_probs.shape)
|
||||||
|
target_hidden_states = target_sampler_output.hidden_states
|
||||||
|
if target_hidden_states is not None:
|
||||||
|
target_hidden_states = target_hidden_states.reshape(
|
||||||
|
*target_token_ids.shape, target_hidden_states.shape[-1])
|
||||||
|
|
||||||
|
return (target_token_ids, target_probs, target_logprobs,
|
||||||
|
target_hidden_states)
|
||||||
|
|
||||||
def _create_scoring_model_input(
|
def _create_scoring_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
token_chunk_size=1,
|
token_chunk_size=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _split_scoring_output(
|
def _split_scoring_output(
|
||||||
self, sampler_output: SamplerOutput, num_scoring_tokens: int
|
sampler_output: SamplerOutput, num_scoring_tokens: int
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||||
torch.Tensor, Optional[torch.Tensor]]:
|
torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
#
|
#
|
||||||
# First samples are from speculative scoring, latter samples are non-
|
# First samples are from speculative scoring, latter samples are non-
|
||||||
# speculative samples.
|
# speculative samples.
|
||||||
split_sizes = [
|
split_sizes = (num_scoring_tokens,
|
||||||
num_scoring_tokens,
|
sampler_output.sampled_token_ids.numel() -
|
||||||
sampler_output.sampled_token_ids.numel() - num_scoring_tokens
|
num_scoring_tokens)
|
||||||
]
|
|
||||||
(spec_probs, non_spec_probs
|
(spec_probs, non_spec_probs
|
||||||
) = sampler_output.sampled_token_probs.split(split_sizes)
|
) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||||
(spec_sampled_tokens, non_spec_sampled_tokens
|
(spec_sampled_tokens, non_spec_sampled_tokens
|
||||||
|
@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
else:
|
else:
|
||||||
spec_hidden_states, non_spec_hidden_states = None, None
|
spec_hidden_states, non_spec_hidden_states = None, None
|
||||||
|
|
||||||
# Convert scores to tensors.
|
return (spec_sampled_tokens, spec_probs, spec_logprobs,
|
||||||
sampler_output.sampled_token_probs = spec_probs
|
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
|
||||||
sampler_output.sampled_token_ids = spec_sampled_tokens
|
non_spec_logprobs, non_spec_hidden_states)
|
||||||
sampler_output.logprobs = spec_logprobs
|
|
||||||
sampler_output.hidden_states = spec_hidden_states
|
|
||||||
(target_token_ids, target_probs, target_logprobs,
|
|
||||||
target_hidden_states) = sampler_output_to_torch([sampler_output],
|
|
||||||
True)
|
|
||||||
|
|
||||||
# Convert non-speculative output tokens to tensors.
|
|
||||||
sampler_output.sampled_token_probs = non_spec_probs
|
|
||||||
sampler_output.sampled_token_ids = non_spec_sampled_tokens
|
|
||||||
sampler_output.logprobs = non_spec_logprobs
|
|
||||||
sampler_output.hidden_states = non_spec_hidden_states
|
|
||||||
(non_spec_target_token_ids, non_spec_target_probs,
|
|
||||||
non_spec_target_logprobs,
|
|
||||||
non_spec_target_hidden_states) = sampler_output_to_torch(
|
|
||||||
[sampler_output], True)
|
|
||||||
|
|
||||||
return (target_token_ids, target_probs, target_logprobs,
|
|
||||||
target_hidden_states, non_spec_target_token_ids,
|
|
||||||
non_spec_target_probs, non_spec_target_logprobs,
|
|
||||||
non_spec_target_hidden_states)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _create_target_seq_id_iterator(
|
def _create_target_seq_id_iterator(
|
||||||
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||||
"""Create an iterator for creating target sequence ids.
|
"""Create an iterator for creating target sequence ids.
|
||||||
Target sequence ids are distinct from sequence ids because we create a
|
Target sequence ids are distinct from sequence ids because we create a
|
||||||
distinct target sequence id for each proposal token to be scored.
|
distinct target sequence id for each proposal token to be scored.
|
||||||
|
@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
"""
|
"""
|
||||||
return count(start=max(seq_ids) + 1)
|
return count(start=max(seq_ids) + 1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _get_token_ids_to_score(
|
def _get_token_ids_to_score(
|
||||||
self,
|
|
||||||
full_spec_token_ids: List[TokenId] # shape: [k]
|
full_spec_token_ids: List[TokenId] # shape: [k]
|
||||||
) -> List[List[TokenId]]:
|
) -> List[List[TokenId]]:
|
||||||
"""Given an int tensor of proposal token ids, return a list of
|
"""Given an int tensor of proposal token ids, return a list of
|
||||||
|
@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||||
empty_token_ids: List[TokenId] = []
|
empty_token_ids: List[TokenId] = []
|
||||||
|
|
||||||
token_ids_to_score = [empty_token_ids]
|
token_ids_to_score = [empty_token_ids]
|
||||||
token_ids_to_score.extend([
|
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
|
||||||
full_spec_token_ids[:i + 1]
|
for i in range(len(full_spec_token_ids)))
|
||||||
for i in range(len(full_spec_token_ids))
|
|
||||||
])
|
|
||||||
return token_ids_to_score
|
return token_ids_to_score
|
||||||
|
|
|
@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||||
# used during the prefill phase.
|
# used during the prefill phase.
|
||||||
# 2. Auto-disable enabled: The running queue size exceeds
|
# 2. Auto-disable enabled: The running queue size exceeds
|
||||||
# the specified threshold.
|
# the specified threshold.
|
||||||
# 3. No request: There are no requests in the batch.
|
# 3. No request: There are no requests in the batch, or
|
||||||
|
# none of the requests in the batch have spec decoding enabled.
|
||||||
# In any of these cases, the proposer and scorer workers
|
# In any of these cases, the proposer and scorer workers
|
||||||
# are called normally.
|
# are called normally.
|
||||||
no_spec = num_lookahead_slots == 0 or len(
|
no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
|
||||||
execute_model_req.seq_group_metadata_list
|
sgm.num_speculative_tokens == 0
|
||||||
) == 0 or disable_all_speculation
|
for sgm in execute_model_req.seq_group_metadata_list)
|
||||||
|
|
||||||
# Broadcast how many lookahead slots are scheduled for this step, and
|
# Broadcast how many lookahead slots are scheduled for this step, and
|
||||||
# whether all speculation is disabled, to all non-driver workers.
|
# whether all speculation is disabled, to all non-driver workers.
|
||||||
|
@ -415,11 +416,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||||
self, execute_model_req: ExecuteModelRequest) -> bool:
|
self, execute_model_req: ExecuteModelRequest) -> bool:
|
||||||
# When the batch size is too large, disable speculative decoding
|
# When the batch size is too large, disable speculative decoding
|
||||||
# to stop trading off throughput for latency.
|
# to stop trading off throughput for latency.
|
||||||
disable_all_speculation = (execute_model_req.running_queue_size >=
|
return (execute_model_req.running_queue_size >=
|
||||||
self.disable_by_batch_size)
|
self.disable_by_batch_size)
|
||||||
|
|
||||||
return disable_all_speculation
|
|
||||||
|
|
||||||
def _maybe_disable_speculative_tokens(
|
def _maybe_disable_speculative_tokens(
|
||||||
self, disable_all_speculation: bool,
|
self, disable_all_speculation: bool,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
|
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
|
||||||
|
@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||||
# proposal len. This adds some complexity (splitting the batch into spec
|
# proposal len. This adds some complexity (splitting the batch into spec
|
||||||
# and non spec sequences) and should be removed in the future. It can be
|
# and non spec sequences) and should be removed in the future. It can be
|
||||||
# done by supporting per-sequence proposal lens.
|
# done by supporting per-sequence proposal lens.
|
||||||
_, spec_indices = split_batch_by_proposal_len(
|
(_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
|
||||||
seq_group_metadata_list,
|
seq_group_metadata_list, proposal_lens_list)
|
||||||
proposal_lens_list,
|
|
||||||
select_proposal_len_zero=False)
|
|
||||||
_, non_spec_indices = split_batch_by_proposal_len(
|
|
||||||
seq_group_metadata_list,
|
|
||||||
proposal_lens_list,
|
|
||||||
select_proposal_len_zero=True)
|
|
||||||
original_indices = spec_indices + non_spec_indices
|
original_indices = spec_indices + non_spec_indices
|
||||||
|
|
||||||
# Get probabilities of target model, excluding bonus token.
|
# Get probabilities of target model, excluding bonus token.
|
||||||
|
|
|
@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
|
||||||
|
|
||||||
# Currently only proposal lens of 0 or the global batch proposal len
|
# Currently only proposal lens of 0 or the global batch proposal len
|
||||||
# are supported.
|
# are supported.
|
||||||
# If max_proposal_len is defined, then we shall no exceed this
|
# If max_proposal_len is defined, then we shall not exceed this
|
||||||
# quota for nonzero_proposal
|
# quota for nonzero_proposal
|
||||||
new_k = 0
|
new_k = 0
|
||||||
if (self.max_proposal_len is None
|
if (self.max_proposal_len is None
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -98,33 +98,26 @@ def create_sequence_group_output(
|
||||||
|
|
||||||
def split_batch_by_proposal_len(
|
def split_batch_by_proposal_len(
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
proposal_lens: List[int], select_proposal_len_zero: bool
|
proposal_lens: List[int],
|
||||||
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
|
) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[
|
||||||
|
List[SequenceGroupMetadata], List[int]]]:
|
||||||
"""Utility function that splits a batch based on whether the proposal len is
|
"""Utility function that splits a batch based on whether the proposal len is
|
||||||
zero or not. We should remove this once vLLM supports per-sequence proposal
|
zero or not. We should remove this once vLLM supports per-sequence proposal
|
||||||
lens in a batch.
|
lens in a batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if select_proposal_len_zero:
|
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
|
||||||
predicate = lambda proposal_len: proposal_len == 0
|
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
|
||||||
else:
|
for i, (seq_group, proposal_len) in enumerate(
|
||||||
predicate = lambda proposal_len: proposal_len != 0
|
zip(seq_group_metadata_list, proposal_lens)):
|
||||||
|
seq_groups, indices = nonzero_lists if proposal_len else zero_lists
|
||||||
indices = [
|
seq_groups.append(seq_group)
|
||||||
i for i, (_, proposal_len
|
indices.append(i)
|
||||||
) in enumerate(zip(seq_group_metadata_list, proposal_lens))
|
return nonzero_lists, zero_lists
|
||||||
if predicate(proposal_len)
|
|
||||||
]
|
|
||||||
seq_groups = [
|
|
||||||
seq_group for seq_group, proposal_len in zip(
|
|
||||||
seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
|
|
||||||
]
|
|
||||||
|
|
||||||
return seq_groups, indices
|
|
||||||
|
|
||||||
|
|
||||||
def sampler_output_to_torch(
|
def sampler_output_to_torch(
|
||||||
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
|
sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""Utility function which converts a list of SamplerOutput to tensors.
|
"""Utility function which converts a list of SamplerOutput to tensors.
|
||||||
|
|
||||||
|
@ -148,18 +141,12 @@ def sampler_output_to_torch(
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sampler_transposed:
|
|
||||||
sampled_token_probs = sampled_token_probs.transpose(0, 1)
|
|
||||||
|
|
||||||
# shape: [batch_size, num_sampler_output, vocab_size]
|
# shape: [batch_size, num_sampler_output, vocab_size]
|
||||||
sampled_token_logprobs = torch.stack(
|
sampled_token_logprobs = torch.stack(
|
||||||
[sampler_output.logprobs for sampler_output in sampler_output_list],
|
[sampler_output.logprobs for sampler_output in sampler_output_list],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sampler_transposed:
|
|
||||||
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
|
|
||||||
|
|
||||||
# shape: [batch_size, num_sampler_output]
|
# shape: [batch_size, num_sampler_output]
|
||||||
sampled_token_ids = torch.stack(
|
sampled_token_ids = torch.stack(
|
||||||
[
|
[
|
||||||
|
@ -168,7 +155,10 @@ def sampler_output_to_torch(
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sampler_transposed:
|
if sampler_transposed:
|
||||||
|
sampled_token_probs = sampled_token_probs.transpose(0, 1)
|
||||||
|
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
|
||||||
sampled_token_ids = sampled_token_ids.transpose(0, 1)
|
sampled_token_ids = sampled_token_ids.transpose(0, 1)
|
||||||
|
|
||||||
if sampler_output_list[0].hidden_states is not None:
|
if sampler_output_list[0].hidden_states is not None:
|
||||||
|
|
Loading…
Reference in New Issue