diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 4cec70ccf6..e60f4a40d2 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], - +) -> Dict[int, SequenceOutputs]: + +) -> SamplerOutput: 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture. diff --git a/tests/conftest.py b/tests/conftest.py index 92b06f3857..281c9161d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,8 +67,8 @@ class HfRunner: output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, - )[0] - output_ids = output_ids[0].cpu().tolist() + ) + output_ids = output_ids.cpu().tolist() outputs.append((output_ids, output_str)) return outputs @@ -77,8 +77,34 @@ class HfRunner: prompts: List[str], max_tokens: int, ) -> List[Tuple[List[int], str]]: - return self.generate(prompts, do_sample=False, - max_new_tokens=max_tokens) + outputs = self.generate(prompts, + do_sample=False, + max_new_tokens=max_tokens) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + outputs[i] = (output_ids[0], output_str[0]) + return outputs + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[int], str]]: + outputs = self.generate(prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width) + for i in range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [ + x for x in output_ids[j] + if x != self.tokenizer.pad_token_id + ] + outputs[i] = (output_ids, output_str) + return outputs @pytest.fixture @@ -107,15 +133,20 @@ class VllmRunner: prompts: List[str], sampling_params: SamplingParams, ) -> List[Tuple[List[int], str]]: - req_outputs = self.model.generate( - prompts, sampling_params=sampling_params) + req_outputs = self.model.generate(prompts, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt prompt_ids = req_output.prompt_token_ids - output_str = req_output.outputs[0].text - output_ids = req_output.outputs[0].token_ids - outputs.append((prompt_ids + output_ids, prompt_str + output_str)) + req_sample_output_ids = [] + req_sample_output_strs = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs def generate_greedy( @@ -124,7 +155,22 @@ class VllmRunner: max_tokens: int, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - return self.generate(prompts, greedy_params) + outputs = self.generate(prompts, greedy_params) + return [(output_ids[0], output_str[0]) for output_ids, output_str in + outputs] + + def generate_beam_search( + self, + prompts: List[str], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[int], str]]: + beam_search_params = SamplingParams(n=beam_width, + use_beam_search=True, + 
temperature=0.0, + max_tokens=max_tokens) + outputs = self.generate(prompts, beam_search_params) + return outputs @pytest.fixture diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py new file mode 100644 index 0000000000..a491ffa763 --- /dev/null +++ b/tests/samplers/test_beam_search.py @@ -0,0 +1,46 @@ +"""Compare the outputs of HF and vLLM when using beam search. + +Run `pytest tests/samplers/test_beam_search.py --forked`. +""" +import pytest + +# FIXME(zhuohan): The test can not pass if we: +# 1. Increase max_tokens to 256. +# 2. Increase beam_width to 8. +# 3. Use the model "huggyllama/llama-7b". +MAX_TOKENS = [128] +BEAM_WIDTHS = [4] +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_single_input( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, _ = hf_outputs[i] + vllm_output_ids, _ = vllm_outputs[i] + assert len(hf_output_ids) == len(vllm_output_ids) + for j in range(len(hf_output_ids)): + assert hf_output_ids[j] == vllm_output_ids[j], ( + f"Test{i} output{j}:\nHF: {hf_output_ids}\n" + f"vLLM: {vllm_output_ids}") diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 2fea66ad3a..a8262c4722 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -172,9 +172,7 @@ class BlockSpaceManager: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(): - if seq.is_finished(): - continue + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] @@ -203,9 +201,7 @@ class BlockSpaceManager: def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(): - if seq.is_finished(): - continue + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fc2335ccdf..2696cf54ac 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceOutputs, - SequenceStatus) + SequenceGroupMetadata, SequenceStatus) logger = init_logger(__name__) @@ -76,6 +75,7 @@ class Scheduler: num_cpu_blocks=self.cache_config.num_cpu_blocks, ) + # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. self.waiting: List[SequenceGroup] = [] # Sequence groups in the RUNNING state. 
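Note: the scheduler hunks that follow budget admissions with `seq_group.get_max_num_running_seqs()`, whose implementation only appears later in the `vllm/sequence.py` portion of this diff. A minimal sketch of its semantics, written as a simplified free function with the group's state passed in as arguments (illustration only, not part of the patch):

# Illustrative sketch only -- the real method lives on SequenceGroup in
# vllm/sequence.py; the arguments here stand in for the object's own state.
def get_max_num_running_seqs(use_beam_search: bool, best_of: int,
                             num_total_seqs: int,
                             num_running_seqs: int) -> int:
    if use_beam_search:
        # Beam search keeps at most `best_of` beam candidates alive.
        return best_of
    if best_of > num_total_seqs:
        # Prompt stage: only the single prompt sequence exists so far, but
        # `best_of` samples will run once generation starts.
        return best_of
    # Generation stage: count the sequences that are actually still RUNNING.
    return num_running_seqs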
@@ -96,10 +96,11 @@ class Scheduler: if seq_group.request_id in request_ids: # Remove the sequence group from the state queue. state_queue.remove(seq_group) - for seq in seq_group.seqs: + for seq in seq_group.get_seqs(): if seq.is_finished(): continue - self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) request_ids.remove(seq_group.request_id) if not request_ids: return @@ -123,6 +124,10 @@ class Scheduler: if not self.swapped: ignored_seq_groups: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = [] + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) num_batched_tokens = 0 # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups @@ -130,6 +135,9 @@ class Scheduler: while self.waiting: seq_group = self.waiting[0] + assert seq_group.num_seqs() == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") num_prompt_tokens = seq_group.get_seqs()[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( @@ -152,11 +160,7 @@ class Scheduler: # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. - num_new_seqs = seq_group.num_seqs( - status=SequenceStatus.WAITING) - num_curr_seqs = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) + num_new_seqs = seq_group.get_max_num_running_seqs() if (num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs): break @@ -165,6 +169,7 @@ class Scheduler: self._allocate(seq_group) self.running.append(seq_group) num_batched_tokens += num_prompt_tokens + num_curr_seqs += num_new_seqs scheduled.append(seq_group) if scheduled: @@ -210,30 +215,32 @@ class Scheduler: # Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) - while self.swapped and not blocks_to_swap_out: - seq_group = self.swapped[0] - # If the sequence group has been preempted in this step, stop. - if seq_group in preempted: - break - # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): - break + if not preempted: + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - num_curr_seqs = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break + while self.swapped: + seq_group = self.swapped[0] + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): + break - seq_group = self.swapped.pop(0) - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slot(seq_group, blocks_to_copy) - self.running.append(seq_group) + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
+ num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + seq_group = self.swapped.pop(0) + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slot(seq_group, blocks_to_copy) + num_curr_seqs += num_new_seqs + self.running.append(seq_group) + + # Each sequence in the generation phase only takes one token slot. + # Therefore, the number of batched tokens is equal to the number of + # sequences in the RUNNING state. num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running) @@ -275,40 +282,10 @@ class Scheduler: seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs - def update( - self, - seq_outputs: Dict[int, SequenceOutputs], - ) -> List[SequenceGroup]: - scheduled: List[SequenceGroup] = [] - for seq_group in self.running: - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - if seq.seq_id in seq_outputs: - scheduled.append(seq_group) - break + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) - # Update the scheduled sequences and free blocks. - for seq_group in scheduled: - # Process beam search results before processing the new tokens. - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - output = seq_outputs[seq.seq_id] - if seq.seq_id != output.parent_seq_id: - # The sequence is a fork of the parent sequence (beam - # search). Free the current sequence. - self.block_manager.free(seq) - # Fork the parent sequence. - parent_seq = seq_group.find(output.parent_seq_id) - parent_seq.fork(seq) - self.block_manager.fork(parent_seq, seq) - - # Process the new tokens. - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - # Append a new token to the sequence. - output = seq_outputs[seq.seq_id] - seq.append_token_id(output.output_token, output.logprobs) - return scheduled - - def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None: - seq.status = finish_status + def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: @@ -345,8 +322,8 @@ class Scheduler: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not supported. In such a case, - # we use swapping instead. + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. # FIXME(woosuk): This makes our scheduling policy a bit bizarre. # As swapped sequences are prioritized over waiting sequences, # sequence groups with multiple sequences are implicitly prioritized @@ -354,8 +331,7 @@ class Scheduler: # TODO(woosuk): Support recomputation for sequence groups with multiple # sequences. This may require a more sophisticated CUDA kernel. 
if preemption_mode is None: - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - if len(seqs) == 1: + if seq_group.get_max_num_running_seqs() == 1: preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 54141bbe55..4ea443d845 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata, +from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceOutputs, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) @@ -258,14 +259,11 @@ class LLMEngine: # Create the sequences. block_size = self.cache_config.block_size - seqs: List[Sequence] = [] - for _ in range(sampling_params.best_of): - seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) - seqs.append(seq) + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) # Create the sequence group. - seq_group = SequenceGroup(request_id, seqs, sampling_params, + seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time) # Add the sequence group to the scheduler. @@ -303,22 +301,230 @@ class LLMEngine: ] return seq_group_metadata_list, scheduler_outputs, None - def _process_worker_outputs( - self, output, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: - # Update the scheduler with the model outputs. - seq_groups = self.scheduler.update(output) + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = (current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + if early_stopping is False: + highest_attainable_score = (best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. 
+ highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + return current_worst_score >= highest_attainable_score + + def _process_sequence_group_samples( + self, seq_group: SequenceGroup, + samples: List[SequenceOutputs]) -> None: + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutputs] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + self._decode_sequence(seq) + self._check_stop(seq, seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. + # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. 
+ all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. 
+ for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _process_model_outputs( + self, output: SamplerOutput, + scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + # Update the scheduled sequence groups with the model outputs. + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + for seq_group, samples in zip(scheduled_seq_groups, output): + self._process_sequence_group_samples(seq_group, samples) - # Decode the sequences. - self._decode_sequences(seq_groups) - # Stop the sequences that meet the stopping criteria. - self._stop_sequences(seq_groups) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups: + for seq_group in (scheduled_seq_groups + + scheduler_outputs.ignored_seq_groups): request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) @@ -351,7 +557,7 @@ class LLMEngine: blocks_to_copy=scheduler_outputs.blocks_to_copy, ) - return self._process_worker_outputs(output, scheduler_outputs) + return self._process_model_outputs(output, scheduler_outputs) def _log_system_stats( self, @@ -416,55 +622,44 @@ class LLMEngine: f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") self.last_logging_time = now - def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None: - """Decodes the sequence outputs.""" - for seq_group in seq_groups: - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_token, new_output_text = detokenize_incrementally( - self.tokenizer, - seq.output_tokens, - seq.get_last_token_id(), - skip_special_tokens=True, - ) - if new_token is not None: - seq.output_tokens.append(new_token) - seq.output_text = new_output_text + def _decode_sequence(self, seq: Sequence) -> None: + """Decodes the new token for a sequence.""" + new_token, new_output_text = detokenize_incrementally( + self.tokenizer, + seq.output_tokens, + seq.get_last_token_id(), + skip_special_tokens=True, + ) + if new_token is not None: + seq.output_tokens.append(new_token) + seq.output_text = new_output_text - def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: + def _check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" - for seq_group in seq_groups: - sampling_params = seq_group.sampling_params - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - # Check if the sequence has generated a stop string. - stopped = False - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] - self.scheduler.free_seq( - seq, SequenceStatus.FINISHED_STOPPED) - stopped = True - break - if stopped: - continue + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_str)] + seq.status = SequenceStatus.FINISHED_STOPPED + return - # Check if the sequence has reached max_model_len. 
- if seq.get_len() > self.scheduler_config.max_model_len: - self.scheduler.free_seq( - seq, SequenceStatus.FINISHED_LENGTH_CAPPED) - continue - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - self.scheduler.free_seq( - seq, SequenceStatus.FINISHED_LENGTH_CAPPED) - continue - # Check if the sequence has generated the EOS token. - if not sampling_params.ignore_eos: - if seq.get_last_token_id() == self.tokenizer.eos_token_id: - self.scheduler.free_seq( - seq, SequenceStatus.FINISHED_STOPPED) - continue + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == self.tokenizer.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return def _run_workers( self, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 93db03add4..9dcfa42e2a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.parallel_utils.tensor_parallel import ( gather_from_tensor_model_parallel_region) from vllm.sampling_params import SamplingParams -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput, SequenceOutputs _SAMPLING_EPS = 1e-5 @@ -39,7 +39,7 @@ class Sampler(nn.Module): hidden_states: torch.Tensor, input_metadata: InputMetadata, embedding_bias: Optional[torch.Tensor] = None, - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: # Get the hidden states that we use for sampling. hidden_states = _prune_hidden_states(hidden_states, input_metadata) @@ -292,7 +292,13 @@ def _sample_from_prompt( if sampling_params.use_beam_search: # Beam search. beam_width = sampling_params.best_of - _, next_token_ids = torch.topk(prob, beam_width) + # Sample 2 * beam_width candidates to make sure that with high + # probability we can get `beam_width` candidates in addition to + # the finished sequences for the next iteration. See + # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 + # for details. See also HF reference: + # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 + _, next_token_ids = torch.topk(prob, 2 * beam_width) next_token_ids = next_token_ids.tolist() elif sampling_params.temperature < _SAMPLING_EPS: # Greedy sampling. @@ -330,29 +336,11 @@ def _sample_from_generation_tokens( vocab_size = logprobs.size(-1) beam_width = len(seq_ids) - _, topk_ids = torch.topk(logprobs.flatten(), beam_width) + _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() seq_idx = [i // vocab_size for i in topk_ids] - beam_seq_ids = [seq_ids[i] for i in seq_idx] - token_ids = [i % vocab_size for i in topk_ids] - - beam_outputs: Dict[int, Tuple[int, int]] = {} - outstanding_beams: List[Tuple[int, int]] = [] - # If a beam survives, continue with it. 
- for seq_id, token_id in zip(beam_seq_ids, token_ids): - if seq_id not in beam_outputs: - beam_outputs[seq_id] = (seq_id, token_id) - else: - outstanding_beams.append((seq_id, token_id)) - - # If a beam is discarded, fork another beam. - for seq_id in seq_ids: - if seq_id not in beam_outputs: - beam_outputs[seq_id] = outstanding_beams.pop() - assert not outstanding_beams - - parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids] - next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids] + parent_seq_ids = [seq_ids[i] for i in seq_idx] + next_token_ids = [i % vocab_size for i in topk_ids] elif sampling_params.temperature < _SAMPLING_EPS: # Greedy sampling. assert len(seq_ids) == 1 @@ -374,16 +362,18 @@ def _sample( probs: torch.Tensor, logprobs: torch.Tensor, input_metadata: InputMetadata, -) -> Dict[int, SequenceOutputs]: - seq_outputs: Dict[int, SequenceOutputs] = {} +) -> SamplerOutput: + seq_outputs: SamplerOutput = [] # TODO(woosuk): Optimize. idx = 0 for i, seq_group in enumerate(input_metadata.seq_groups): + seq_group_outputs: List[SequenceOutputs] = [] seq_ids, sampling_params = seq_group if i < input_metadata.num_prompts: # Generate the next tokens for a prompt input. - assert len(seq_ids) == sampling_params.best_of + assert len(seq_ids) == 1, "Prompt input should have only one seq." + parent_seq_id = seq_ids[0] prob = probs[idx] logprob = logprobs[idx] idx += 1 @@ -395,17 +385,18 @@ def _sample( sampling_params.logprobs) # Build the output. - for seq_id, next_token_id in zip(seq_ids, next_token_ids): + for next_token_id in next_token_ids: output_logprobs = next_logprobs.copy() output_logprobs[next_token_id] = logprob[next_token_id].item() - seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id, - next_token_id, - output_logprobs) + seq_group_outputs.append( + SequenceOutputs(parent_seq_id, next_token_id, + output_logprobs)) else: # Generate the next tokens for generation tokens. - prob = probs[idx:idx + len(seq_ids)] - logprob = logprobs[idx:idx + len(seq_ids)] - idx += len(seq_ids) + num_parent_seqs = len(seq_ids) + prob = probs[idx:idx + num_parent_seqs] + logprob = logprobs[idx:idx + num_parent_seqs] + idx += num_parent_seqs # Sample the next tokens. seq_logprobs = [ @@ -422,17 +413,15 @@ def _sample( logprob[j], sampling_params.logprobs) # Build the output. - for seq_id, parent_seq_id, next_token_id in zip( - seq_ids, parent_seq_ids, next_token_ids): + for parent_seq_id, next_token_id in zip(parent_seq_ids, + next_token_ids): j = seq_ids.index(parent_seq_id) output_logprobs = next_logprobs[parent_seq_id].copy() output_logprobs[next_token_id] = logprob[j, next_token_id].item() - seq_outputs[seq_id] = SequenceOutputs( - seq_id, - parent_seq_id, - next_token_id, - output_logprobs, - ) + seq_group_outputs.append( + SequenceOutputs(parent_seq_id, next_token_id, + output_logprobs)) + seq_outputs.append(seq_group_outputs) return seq_outputs diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 194da4ff54..189a31d36c 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -25,7 +25,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 47c79059b7..9ebe37d66a 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ import math -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn -from vllm.sequence import SequenceOutputs from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4a3de8d469..c7e3ecf15b 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" import math -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 7730b23189..8a2a9fedff 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -19,7 +19,7 @@ """PyTorch Falcon model.""" import math -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, reduce_from_tensor_model_parallel_region) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer( input_ids, positions, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index a3b6efe2af..18a11b4907 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 8694c97583..91c432bf8b 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -22,7 +22,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index cf89e28bda..35a1518ec8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -20,7 +20,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index de25029d9e..5319839899 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -20,7 +20,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.embed_out.weight, hidden_states, diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 1aeeb91d94..50b26fcd36 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.weight_utils import ( hf_model_weights_iterator, load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d72c4ff6a0..62da553218 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -25,7 +25,7 @@ The input of the 
model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 5afcb7a0ae..7a75fac62b 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,7 +1,7 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index b8d6bdd424..9bd503ae42 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 1318bbe702..d511ef96b6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,7 +8,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import ( ColumnParallelLinear, RowParallelLinear, ) -from vllm.sequence import SequenceOutputs +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.qwen import QWenConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head.weight, hidden_states, diff --git a/vllm/outputs.py b/vllm/outputs.py index d453b94fe1..64ba8440e3 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -75,10 +75,12 @@ class RequestOutput: # Get the top-n sequences. n = seq_group.sampling_params.n seqs = seq_group.get_seqs() - assert n <= len(seqs) - sorted_seqs = sorted(seqs, - key=lambda seq: seq.get_cumulative_logprob(), - reverse=True) + if seq_group.sampling_params.use_beam_search: + sorting_key = lambda seq: seq.get_beam_search_score( + seq_group.sampling_params.length_penalty) + else: + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) top_n_seqs = sorted_seqs[:n] # Create the outputs. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 91f2cb1bf1..d24404b8d3 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -34,6 +34,15 @@ class SamplingParams: top_k: Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. use_beam_search: Whether to use beam search instead of sampling. + length_penalty: Float that penalizes sequences based on their length. + Used in beam search. + early_stopping: Controls the stopping condition for beam search. 
It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm). stop: List of strings that stop the generation when they are generated. The returned output will not contain the stop strings. ignore_eos: Whether to ignore the EOS token and continue generating @@ -52,6 +61,8 @@ class SamplingParams: top_p: float = 1.0, top_k: int = -1, use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, stop: Union[None, str, List[str]] = None, ignore_eos: bool = False, max_tokens: int = 16, @@ -65,6 +76,8 @@ class SamplingParams: self.top_p = top_p self.top_k = top_k self.use_beam_search = use_beam_search + self.length_penalty = length_penalty + self.early_stopping = early_stopping if stop is None: self.stop = [] elif isinstance(stop, str): @@ -78,9 +91,11 @@ class SamplingParams: self._verify_args() if self.use_beam_search: self._verify_beam_search() - elif self.temperature < _SAMPLING_EPS: - # Zero temperature means greedy sampling. - self._verify_greedy_sampling() + else: + self._verify_non_beam_search() + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. + self._verify_greedy_sampling() def _verify_args(self) -> None: if self.n < 1: @@ -119,6 +134,20 @@ class SamplingParams: raise ValueError("top_p must be 1 when using beam search.") if self.top_k != -1: raise ValueError("top_k must be -1 when using beam search.") + if self.early_stopping not in [True, False, "never"]: + raise ValueError( + f"early_stopping must be True, False, or 'never', " + f"got {self.early_stopping}.") + + def _verify_non_beam_search(self) -> None: + if self.early_stopping is not False: + raise ValueError("early_stopping is not effective and must be " + "False when not using beam search.") + if (self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS): + raise ValueError( + "length_penalty is not effective and must be the " + "default value of 1.0 when not using beam search.") def _verify_greedy_sampling(self) -> None: if self.best_of > 1: @@ -138,6 +167,8 @@ class SamplingParams: f"top_p={self.top_p}, " f"top_k={self.top_k}, " f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " f"stop={self.stop}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 87c80c5dcd..74682786da 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -69,6 +69,9 @@ class SequenceData: def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) + def get_prompt_len(self) -> int: + return len(self.prompt_token_ids) + def get_output_len(self) -> int: return len(self.output_token_ids) @@ -155,6 +158,9 @@ class Sequence: def get_len(self) -> int: return self.data.get_len() + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + def get_output_len(self) -> int: return self.data.get_output_len() @@ -170,14 +176,32 @@ class Sequence: def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob + def get_beam_search_score(self, + length_penalty: float = 0.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None) 
-> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + if seq_len is None: + seq_len = self.get_len() + # Note: HF implementation does not count the EOS token + # towards the length, we align with that here for testing. + if (eos_token_id is not None + and self.get_last_token_id() == eos_token_id): + seq_len -= 1 + return self.get_cumulative_logprob() / (seq_len**length_penalty) + def is_finished(self) -> bool: return SequenceStatus.is_finished(self.status) - def fork(self, child_seq: "Sequence") -> None: - child_seq.logical_token_blocks = copy.deepcopy( - self.logical_token_blocks) - child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) - child_seq.data = copy.deepcopy(self.data) + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq def __repr__(self) -> str: return (f"Sequence(seq_id={self.seq_id}, " @@ -203,35 +227,66 @@ class SequenceGroup: arrival_time: float, ) -> None: self.request_id = request_id - self.seqs = seqs + self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.sampling_params.use_beam_search: + # For beam search, maximally there will always be `best_of` beam + # candidates running in the future. + return self.sampling_params.best_of + else: + if self.sampling_params.best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences running. + return self.sampling_params.best_of + # At sampling stages, return the number of actual sequences + # running. 
+ return self.num_seqs(status=SequenceStatus.RUNNING) + def get_seqs( self, status: Optional[SequenceStatus] = None, ) -> List[Sequence]: if status is None: - return self.seqs + return list(self.seqs_dict.values()) else: - return [seq for seq in self.seqs if seq.status == status] + return [ + seq for seq in self.seqs_dict.values() if seq.status == status + ] + + def get_finished_seqs(self) -> List[Sequence]: + return [seq for seq in self.seqs_dict.values() if seq.is_finished()] def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) def find(self, seq_id: int) -> Sequence: - for seq in self.seqs: - if seq.seq_id == seq_id: - return seq - raise ValueError(f"Sequence {seq_id} not found.") + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + return self.seqs_dict[seq_id] + + def add(self, seq: Sequence) -> None: + if seq.seq_id in self.seqs_dict: + raise ValueError(f"Sequence {seq.seq_id} already exists.") + self.seqs_dict[seq.seq_id] = seq + + def remove(self, seq_id: int) -> None: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + del self.seqs_dict[seq_id] def is_finished(self) -> bool: - return all(seq.is_finished() for seq in self.seqs) + return all(seq.is_finished() for seq in self.get_seqs()) def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") + f"num_seqs={len(self.seqs_dict)})") class SequenceGroupMetadata: @@ -266,7 +321,6 @@ class SequenceOutputs: """The model output associated with a sequence. Args: - seq_id: The ID of the sequence. parent_seq_id: The ID of the parent sequence (for forking in beam search). output_token: The output token ID. @@ -276,26 +330,27 @@ class SequenceOutputs: def __init__( self, - seq_id: int, parent_seq_id: int, output_token: int, logprobs: Dict[int, float], ) -> None: - self.seq_id = seq_id self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs def __repr__(self) -> str: - return (f"SequenceOutputs(seq_id={self.seq_id}, " - f"parent_seq_id={self.parent_seq_id}, " + return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}), " f"logprobs={self.logprobs}") def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceOutputs): - return NotImplemented - return (self.seq_id == other.seq_id - and self.parent_seq_id == other.parent_seq_id + return NotImplementedError() + return (self.parent_seq_id == other.parent_seq_id and self.output_token == other.output_token and self.logprobs == other.logprobs) + + +# For each sequence group, we generate a list of SequenceOutputs object, +# each of which contains one possible candidate for the next token. 
+SamplerOutput = List[List[SequenceOutputs]] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bee7d441c6..2d2021d9fe 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sampling_params import SamplingParams -from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.utils import get_gpu_memory @@ -260,7 +260,7 @@ class Worker: blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: # Issue cache operations. issued_cache_op = False if blocks_to_swap_in:
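For reference, the score introduced in `Sequence.get_beam_search_score()` above is the cumulative log-probability normalized by `seq_len ** length_penalty`, with a trailing EOS token excluded from the length to match the HF convention. A self-contained numeric sketch (made-up values, not part of the patch):

# Standalone illustration of the beam-search scoring rule; numbers are made up.
from typing import List, Optional


def beam_search_score(cumulative_logprob: float,
                      token_ids: List[int],
                      length_penalty: float = 1.0,
                      eos_token_id: Optional[int] = None) -> float:
    seq_len = len(token_ids)
    # Match the HF convention: a trailing EOS token does not count towards
    # the length used for normalization.
    if (eos_token_id is not None and token_ids
            and token_ids[-1] == eos_token_id):
        seq_len -= 1
    return cumulative_logprob / (seq_len**length_penalty)


# With length_penalty=1.0 this is the average log-probability per token:
# a total logprob of -6.0 over 4 non-EOS tokens scores -6.0 / 4 = -1.5.
print(beam_search_score(-6.0, [11, 42, 7, 9, 2], eos_token_id=2))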