Align vLLM's beam search implementation with HF generate (#857)

This commit is contained in:
Zhuohan Li 2023-09-04 17:29:42 -07:00 committed by GitHub
parent e15932bb60
commit 002800f081
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 596 additions and 260 deletions

View File

@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ kv_caches: List[KVCache],
+ input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]],
+) -> Dict[int, SequenceOutputs]:
+) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.

View File

@ -67,8 +67,8 @@ class HfRunner:
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
output_ids = output_ids[0].cpu().tolist()
)
output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str))
return outputs
@ -77,8 +77,34 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
) -> List[Tuple[List[int], str]]:
return self.generate(prompts, do_sample=False,
max_new_tokens=max_tokens)
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
return outputs
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [
x for x in output_ids[j]
if x != self.tokenizer.pad_token_id
]
outputs[i] = (output_ids, output_str)
return outputs
@pytest.fixture
@ -107,15 +133,20 @@ class VllmRunner:
prompts: List[str],
sampling_params: SamplingParams,
) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate(
prompts, sampling_params=sampling_params)
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
output_str = req_output.outputs[0].text
output_ids = req_output.outputs[0].token_ids
outputs.append((prompt_ids + output_ids, prompt_str + output_str))
req_sample_output_ids = []
req_sample_output_strs = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def generate_greedy(
@ -124,7 +155,22 @@ class VllmRunner:
max_tokens: int,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
return self.generate(prompts, greedy_params)
outputs = self.generate(prompts, greedy_params)
return [(output_ids[0], output_str[0]) for output_ids, output_str in
outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
@pytest.fixture

View File

@ -0,0 +1,46 @@
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")

View File

@ -172,9 +172,7 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
@ -203,9 +201,7 @@ class BlockSpaceManager:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]

View File

@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
SequenceGroupMetadata, SequenceStatus)
logger = init_logger(__name__)
@ -76,6 +75,7 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks,
)
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
@ -96,10 +96,11 @@ class Scheduler:
if seq_group.request_id in request_ids:
# Remove the sequence group from the state queue.
state_queue.remove(seq_group)
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
request_ids.remove(seq_group.request_id)
if not request_ids:
return
@ -123,6 +124,10 @@ class Scheduler:
if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
@ -130,6 +135,9 @@ class Scheduler:
while self.waiting:
seq_group = self.waiting[0]
assert seq_group.num_seqs() == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
@ -152,11 +160,7 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
@ -165,6 +169,7 @@ class Scheduler:
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
if scheduled:
@ -210,30 +215,32 @@ class Scheduler:
# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped and not blocks_to_swap_out:
seq_group = self.swapped[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
while self.swapped:
seq_group = self.swapped[0]
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
self.running.append(seq_group)
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
# Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
@ -275,40 +282,10 @@ class Scheduler:
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
def update(
self,
seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]:
scheduled: List[SequenceGroup] = []
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_manager.fork(parent_seq, child_seq)
# Update the scheduled sequences and free blocks.
for seq_group in scheduled:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam
# search). Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
return scheduled
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status
def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None:
@ -345,8 +322,8 @@ class Scheduler:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case,
# we use swapping instead.
# (e.g., beam search), recomputation is not currently supported. In
# such a case, we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized
@ -354,8 +331,7 @@ class Scheduler:
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
if len(seqs) == 1:
if seq_group.get_max_num_running_seqs() == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP

View File

@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
@ -258,14 +259,11 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seqs: List[Sequence] = []
for _ in range(sampling_params.best_of):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
# Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params,
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time)
# Add the sequence group to the scheduler.
@ -303,22 +301,230 @@ class LLMEngine:
]
return seq_group_metadata_list, scheduler_outputs, None
def _process_worker_outputs(
self, output,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduler with the model outputs.
seq_groups = self.scheduler.update(output)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
# Decode the sequences.
self._decode_sequences(seq_groups)
# Stop the sequences that meet the stopping criteria.
self._stop_sequences(seq_groups)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
for seq_group in (scheduled_seq_groups +
scheduler_outputs.ignored_seq_groups):
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
@ -351,7 +557,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
return self._process_worker_outputs(output, scheduler_outputs)
return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats(
self,
@ -416,55 +622,44 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
if new_token is not None:
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _decode_sequence(self, seq: Sequence) -> None:
"""Decodes the new token for a sequence."""
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
if new_token is not None:
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Check if the sequence has generated a stop string.
stopped = False
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
continue
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
continue
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _run_workers(
self,

View File

@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput, SequenceOutputs
_SAMPLING_EPS = 1e-5
@ -39,7 +39,7 @@ class Sampler(nn.Module):
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
@ -292,7 +292,13 @@ def _sample_from_prompt(
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width)
# Sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
_, next_token_ids = torch.topk(prob, 2 * beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
@ -330,29 +336,11 @@ def _sample_from_generation_tokens(
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = [i % vocab_size for i in topk_ids]
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
parent_seq_ids = [seq_ids[i] for i in seq_idx]
next_token_ids = [i % vocab_size for i in topk_ids]
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
assert len(seq_ids) == 1
@ -374,16 +362,18 @@ def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
seq_outputs: Dict[int, SequenceOutputs] = {}
) -> SamplerOutput:
seq_outputs: SamplerOutput = []
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_group_outputs: List[SequenceOutputs] = []
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.best_of
assert len(seq_ids) == 1, "Prompt input should have only one seq."
parent_seq_id = seq_ids[0]
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@ -395,17 +385,18 @@ def _sample(
sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
for next_token_id in next_token_ids:
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
next_token_id,
output_logprobs)
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
logprob = logprobs[idx:idx + len(seq_ids)]
idx += len(seq_ids)
num_parent_seqs = len(seq_ids)
prob = probs[idx:idx + num_parent_seqs]
logprob = logprobs[idx:idx + num_parent_seqs]
idx += num_parent_seqs
# Sample the next tokens.
seq_logprobs = [
@ -422,17 +413,15 @@ def _sample(
logprob[j], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
for parent_seq_id, next_token_id in zip(parent_seq_ids,
next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
next_token_id,
output_logprobs,
)
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
seq_outputs.append(seq_group_outputs)
return seq_outputs

View File

@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -19,7 +19,7 @@
"""PyTorch Falcon model."""
import math
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(
input_ids,
positions,

View File

@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -22,7 +22,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states,

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -1,7 +1,7 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -8,7 +8,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -75,10 +75,12 @@ class RequestOutput:
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
assert n <= len(seqs)
sorted_seqs = sorted(seqs,
key=lambda seq: seq.get_cumulative_logprob(),
reverse=True)
if seq_group.sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.

View File

@ -34,6 +34,15 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating
@ -52,6 +61,8 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
@ -65,6 +76,8 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = []
elif isinstance(stop, str):
@ -78,9 +91,11 @@ class SamplingParams:
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
elif self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
def _verify_args(self) -> None:
if self.n < 1:
@ -119,6 +134,20 @@ class SamplingParams:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
if self.best_of > 1:
@ -138,6 +167,8 @@ class SamplingParams:
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "

View File

@ -69,6 +69,9 @@ class SequenceData:
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
def get_output_len(self) -> int:
return len(self.output_token_ids)
@ -155,6 +158,9 @@ class Sequence:
def get_len(self) -> int:
return self.data.get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int:
return self.data.get_output_len()
@ -170,14 +176,32 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# Note: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: "Sequence") -> None:
child_seq.logical_token_blocks = copy.deepcopy(
self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.data = copy.deepcopy(self.data)
def fork(self, new_seq_id: int) -> "Sequence":
new_seq = copy.deepcopy(self)
new_seq.seq_id = new_seq_id
return new_seq
def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, "
@ -203,35 +227,66 @@ class SequenceGroup:
arrival_time: float,
) -> None:
self.request_id = request_id
self.seqs = seqs
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
else:
if self.sampling_params.best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
# At sampling stages, return the number of actual sequences
# running.
return self.num_seqs(status=SequenceStatus.RUNNING)
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
if status is None:
return self.seqs
return list(self.seqs_dict.values())
else:
return [seq for seq in self.seqs if seq.status == status]
return [
seq for seq in self.seqs_dict.values() if seq.status == status
]
def get_finished_seqs(self) -> List[Sequence]:
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
def find(self, seq_id: int) -> Sequence:
for seq in self.seqs:
if seq.seq_id == seq_id:
return seq
raise ValueError(f"Sequence {seq_id} not found.")
if seq_id not in self.seqs_dict:
raise ValueError(f"Sequence {seq_id} not found.")
return self.seqs_dict[seq_id]
def add(self, seq: Sequence) -> None:
if seq.seq_id in self.seqs_dict:
raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
def remove(self, seq_id: int) -> None:
if seq_id not in self.seqs_dict:
raise ValueError(f"Sequence {seq_id} not found.")
del self.seqs_dict[seq_id]
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs)
return all(seq.is_finished() for seq in self.get_seqs())
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})")
f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata:
@ -266,7 +321,6 @@ class SequenceOutputs:
"""The model output associated with a sequence.
Args:
seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
@ -276,26 +330,27 @@ class SequenceOutputs:
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float],
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"SequenceOutputs(seq_id={self.seq_id}, "
f"parent_seq_id={self.parent_seq_id}, "
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}), "
f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
return NotImplemented
return (self.seq_id == other.seq_id
and self.parent_seq_id == other.parent_seq_id
return NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
# For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[List[SequenceOutputs]]

View File

@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory
@ -260,7 +260,7 @@ class Worker:
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> Dict[int, SequenceOutputs]:
) -> SamplerOutput:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in: