mirror of https://github.com/vllm-project/vllm
[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)
This commit is contained in:
parent
a88081bf76
commit
603ad84815
|
@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
|
|||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
|
||||
def test_get_prompt_logprobs(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype,
|
||||
chunked_prefill_token_size: int,
|
||||
num_top_logprobs: int,
|
||||
example_prompts,
|
||||
):
|
||||
max_num_seqs = 256
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
max_tokens = 5
|
||||
num_top_logprobs = 6
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_logprobs = hf_model.generate_greedy_logprobs(
|
||||
example_prompts,
|
||||
|
@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
|
|||
)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_logprobs=num_top_logprobs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=num_top_logprobs,
|
||||
prompt_logprobs=5,
|
||||
prompt_logprobs=num_top_logprobs,
|
||||
temperature=0.0)
|
||||
vllm_results = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=vllm_sampling_params)
|
||||
|
@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
|
|||
"The output text from the top logprob for each token position "
|
||||
"should be the same as the output text in the result.")
|
||||
|
||||
# The first prompt logprob is always None
|
||||
assert result.prompt_logprobs[0] is None
|
||||
for prompt_logprobs in result.prompt_logprobs[1:]:
|
||||
# If the prompt token is not included in the top X
|
||||
# logprob, it can return 1 more data
|
||||
assert (len(prompt_logprobs) == num_top_logprobs
|
||||
or len(prompt_logprobs) == num_top_logprobs + 1)
|
||||
|
||||
# Test whether prompt logprobs are consistent with HF
|
||||
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
|
||||
# Check prompt logprobs
|
||||
# The first prompt logprob is always None, so we compare it from 1:.
|
||||
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
|
||||
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
|
||||
for token_id, logprob in vllm_prompt_logprob_dict.items():
|
||||
|
@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
|
|||
"The token should be decoded by the time it is returned "
|
||||
" to the user.")
|
||||
|
||||
# Test if prompt logprobs are correctly set.
|
||||
for vllm_result in vllm_results:
|
||||
token_ids = vllm_result.prompt_token_ids
|
||||
prompt_logprobs = vllm_result.prompt_logprobs
|
||||
|
||||
# The first token doesn't have logprob.
|
||||
assert prompt_logprobs[0] is None
|
||||
|
||||
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
|
||||
assert token_id in logprob_dict
|
||||
|
||||
|
||||
def test_max_logprobs():
|
||||
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import Counter
|
||||
|
@ -54,6 +55,7 @@ def _do_sample(
|
|||
sampler: MockLogitsSampler,
|
||||
model_runner: ModelRunner,
|
||||
sampling_params: SamplingParams,
|
||||
device: str,
|
||||
):
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
|
@ -68,9 +70,12 @@ def _do_sample(
|
|||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
|
@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
|
|||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||
sampling_params)
|
||||
sampling_params, device)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
|
@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
|
|||
n=random.randint(1, 10),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||
sampling_params)
|
||||
sampling_params, device)
|
||||
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
|
@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
|
|||
seed=random.randint(0, 10000),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
|
||||
sampling_params)
|
||||
sampling_params, device)
|
||||
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
|
@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
|
|||
seed=random.randint(0, 10000),
|
||||
)
|
||||
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
model_runner, sampling_params)
|
||||
model_runner, sampling_params, device)
|
||||
|
||||
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
model_runner, sampling_params)
|
||||
model_runner, sampling_params, device)
|
||||
|
||||
assert first_sampler_output == second_sampler_output
|
||||
|
||||
|
@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
|
|||
best_of=2,
|
||||
use_beam_search=True,
|
||||
)
|
||||
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
|
||||
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
|
||||
device)
|
||||
# no assertion here as I am not sure how to determine whether
|
||||
# the outputs are expected - in other words, this just tests
|
||||
# whether there are no exceptions in the sampler
|
||||
|
@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
|||
"batch size")
|
||||
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens=prompt_lens if prompt_lens else None,
|
||||
subquery_lens=prompt_lens if prompt_lens else None)
|
||||
subquery_lens=prompt_lens if prompt_lens else None,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
|
@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
|
|||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
def test_sampling(model_runner: ModelRunner):
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
sampler_output = sampler(logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
|
@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
|
||||
sample_probs = None
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
|
|||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
logits_processor_output = logits_processor(
|
||||
embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
|
||||
|
||||
|
@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
|
|||
assert len(input_positions) == sum(prompt_lens)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
|
@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
|||
for prompt_len in prompt_lens:
|
||||
expected_selected_token_indices.append(selected_token_start_idx)
|
||||
selected_token_start_idx += 1
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
|
|
|
@ -915,6 +915,20 @@ class Scheduler:
|
|||
self.block_manager.get_common_computed_block_ids(
|
||||
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
|
||||
|
||||
do_sample = True
|
||||
if seq_group.is_prefill():
|
||||
seqs = seq_group.get_seqs()
|
||||
# Prefill has only 1 sequence.
|
||||
assert len(seqs) == 1
|
||||
# In the next iteration, all prompt tokens are not computed.
|
||||
# It means the prefill is chunked, and we don't need sampling.
|
||||
# NOTE: We use get_len instead of get_prompt_len because when
|
||||
# a sequence is preempted, prefill includes previous generated
|
||||
# output tokens.
|
||||
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
|
||||
seqs[0].data.get_len()):
|
||||
do_sample = False
|
||||
|
||||
# It assumes the scheduled_seq_groups is ordered by
|
||||
# prefill < decoding.
|
||||
is_prompt = seq_group.is_prefill()
|
||||
|
@ -924,6 +938,7 @@ class Scheduler:
|
|||
seq_data=seq_data,
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
do_sample=do_sample,
|
||||
token_chunk_size=token_chunk_size,
|
||||
lora_request=seq_group.lora_request,
|
||||
computed_block_nums=common_computed_block_nums,
|
||||
|
|
|
@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups)
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
|
||||
# Log stats.
|
||||
if self.log_stats:
|
||||
|
|
|
@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
|
|||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
||||
SequenceGroup, SequenceStage)
|
||||
SequenceGroup, SequenceGroupMetadata)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||
get_tokenizer_group)
|
||||
|
@ -476,9 +476,12 @@ class LLMEngine:
|
|||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _process_model_outputs(
|
||||
self, output: List[SamplerOutput],
|
||||
scheduled_seq_groups: List[SequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
||||
self,
|
||||
output: List[SamplerOutput],
|
||||
scheduled_seq_groups: List[SequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> List[RequestOutput]:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
|
@ -492,17 +495,15 @@ class LLMEngine:
|
|||
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
|
||||
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
|
||||
output_by_sequence_group):
|
||||
for scheduled_seq_group, outputs, seq_group_meta in zip(
|
||||
scheduled_seq_groups, output_by_sequence_group,
|
||||
seq_group_metadata_list):
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
|
||||
# If all sequences in the sequence group are in DECODE, then we can
|
||||
# process the output tokens. Otherwise, they are (chunked) prefill
|
||||
# samples and should not be processed.
|
||||
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
|
||||
if all(stage == SequenceStage.DECODE for stage in stages):
|
||||
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
||||
if seq_group_meta.do_sample:
|
||||
self.output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
|
@ -585,7 +586,7 @@ class LLMEngine:
|
|||
|
||||
request_outputs = self._process_model_outputs(
|
||||
output, scheduler_outputs.scheduled_seq_groups,
|
||||
scheduler_outputs.ignored_seq_groups)
|
||||
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
|
||||
|
||||
# Log stats.
|
||||
if self.log_stats:
|
||||
|
|
|
@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
|
|||
scheduler.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def process_prompt_logprob(self, seq_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
"""Update prompt logprobs received from outputs to seq_group."""
|
||||
pass
|
||||
|
|
|
@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||
self.get_tokenizer_for_seq = get_tokenizer_for_seq
|
||||
self.stop_checker = stop_checker
|
||||
|
||||
def process_prompt_logprob(self, seq_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
# TODO(sang): Prompt logprob currently not implemented in multi step
|
||||
# workers.
|
||||
logger.warning(
|
||||
"Prompt logprob is not supported by multi step workers. "
|
||||
"(e.g., speculative decode uses multi step workers).")
|
||||
pass
|
||||
|
||||
def process_outputs(self, sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
"""Append new tokens in the outputs to sequences in the sequence group.
|
||||
|
|
|
@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||
), f"{type(self)} does not support multiple outputs per step"
|
||||
return self._process_sequence_group_outputs(sequence_group, outputs[0])
|
||||
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutput) -> None:
|
||||
|
||||
# Process prompt logprobs
|
||||
prompt_logprobs = outputs.prompt_logprobs
|
||||
if prompt_logprobs is not None and \
|
||||
seq_group.sampling_params.detokenize and self.detokenizer:
|
||||
def process_prompt_logprob(self, seq_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
assert len(outputs) == 1, ("Single step should only has 1 output.")
|
||||
output = outputs[0]
|
||||
prompt_logprobs = output.prompt_logprobs
|
||||
if (prompt_logprobs is not None
|
||||
and seq_group.sampling_params.detokenize and self.detokenizer):
|
||||
self.detokenizer.decode_prompt_logprobs_inplace(
|
||||
seq_group, prompt_logprobs)
|
||||
seq_group.prompt_logprobs = prompt_logprobs
|
||||
if not seq_group.prompt_logprobs:
|
||||
# The first prompt token's logprob is None because it doesn't
|
||||
# have tokens that are precedent.
|
||||
seq_group.prompt_logprobs = [None]
|
||||
seq_group.prompt_logprobs.extend(prompt_logprobs)
|
||||
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutput) -> None:
|
||||
# Process samples
|
||||
samples = outputs.samples
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from typing import List
|
||||
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupOutput
|
||||
|
||||
|
||||
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
|
||||
num_seq_groups: int):
|
||||
def create_output_by_sequence_group(
|
||||
sampler_outputs: List[SamplerOutput],
|
||||
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
|
||||
"""Helper method which transforms a 2d list organized by
|
||||
[step][sequence group] into [sequence group][step].
|
||||
"""
|
||||
|
|
|
@ -83,30 +83,27 @@ def _apply_logits_processors(
|
|||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
logits_row_idx = 0
|
||||
found_logits_processors = False
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
logits_processed = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
logits_processors = sampling_params.logits_processors
|
||||
# handle prompt_logprobs by skipping rows in logits added for
|
||||
# the prompt tokens (prompt logprobs are not processed)
|
||||
if (i < sampling_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
assert len(seq_ids) == 1
|
||||
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
|
||||
|
||||
if logits_processors:
|
||||
found_logits_processors = True
|
||||
for seq_id in seq_ids:
|
||||
for seq_id, logits_row_idx in zip(seq_ids,
|
||||
seq_group.sample_indices):
|
||||
logits_row = logits[logits_row_idx]
|
||||
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
|
||||
token_ids = seq_group.seq_data[seq_id].output_token_ids
|
||||
for logits_processor in logits_processors:
|
||||
logits_row = logits_processor(token_ids, logits_row)
|
||||
logits[logits_row_idx] = logits_row
|
||||
logits_row_idx += 1
|
||||
else:
|
||||
logits_row_idx += len(seq_ids)
|
||||
|
||||
logits_processed += len(seq_group.sample_indices) + len(
|
||||
seq_group.prompt_logprob_indices)
|
||||
|
||||
if found_logits_processors:
|
||||
# verifies that no rows in logits were missed unexpectedly
|
||||
assert logits_row_idx == logits.shape[0]
|
||||
assert logits_processed == logits.shape[0]
|
||||
return logits
|
||||
|
|
|
@ -7,11 +7,11 @@ import torch.nn as nn
|
|||
|
||||
from vllm.model_executor.layers.ops.sample import sample as sample_triton
|
||||
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
||||
SamplingTensors)
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
SamplingTensors,
|
||||
SequenceGroupToSample)
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
|
||||
SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||
SequenceOutput)
|
||||
SamplerOutput, SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
@ -48,11 +48,14 @@ class Sampler(nn.Module):
|
|||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
"""
|
||||
Args:
|
||||
logits: (num_tokens, vocab_size).
|
||||
sampling_metadata: Metadata for sampling.
|
||||
"""
|
||||
assert logits is not None
|
||||
_, vocab_size = logits.shape
|
||||
|
||||
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
|
||||
# have not been generated yet
|
||||
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
|
||||
|
||||
# Prepare sampling tensors with pinned memory to avoid blocking.
|
||||
|
@ -83,7 +86,6 @@ class Sampler(nn.Module):
|
|||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities.
|
||||
# Use log_softmax to ensure numerical stability.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
|
@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
|
|||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
|
||||
have not been generated yet
|
||||
"""
|
||||
# list of indices in logits that will be set to -inf
|
||||
logits_to_penalize = []
|
||||
start_idx = 0
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
logits_applied = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
|
||||
# handle prompt_logprobs by skipping rows in logits added for the prompt
|
||||
# tokens (prompt logprobs are not penalized)
|
||||
if (i < sampling_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
assert len(seq_ids) == 1
|
||||
start_idx += sampling_metadata.prompt_lens[i] - 1
|
||||
sample_indices = seq_group.sample_indices
|
||||
logits_applied += len(sample_indices) + len(
|
||||
seq_group.prompt_logprob_indices)
|
||||
if not seq_group.do_sample:
|
||||
continue
|
||||
|
||||
start_idx = sample_indices[0]
|
||||
min_tokens = sampling_params.min_tokens
|
||||
if min_tokens > 0:
|
||||
seqs_to_penalize = []
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
seq_data = sampling_metadata.seq_data[seq_id]
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
if len(seq_data.output_token_ids) < min_tokens:
|
||||
seqs_to_penalize.append(i)
|
||||
|
||||
|
@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
|
|||
logits_to_penalize.extend(
|
||||
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
||||
|
||||
start_idx += len(seq_ids)
|
||||
|
||||
if logits_to_penalize:
|
||||
# use zip and * to group indices along each dimension
|
||||
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
||||
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
||||
|
||||
# verifies that no rows in logits were missed unexpectedly
|
||||
assert start_idx == logits.shape[0]
|
||||
assert logits_applied == logits.shape[0]
|
||||
return logits
|
||||
|
||||
|
||||
|
@ -265,14 +269,30 @@ def _apply_min_p(
|
|||
|
||||
|
||||
def _greedy_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
samples: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
"""Run greedy sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
samples: (num_selected_samples,) A tensor of samples. The length of
|
||||
samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
samples = samples.tolist()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group in selected_seq_groups:
|
||||
seq_ids, _ = seq_group
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
num_parent_seqs = len(seq_ids)
|
||||
assert num_parent_seqs == 1, (
|
||||
"Greedy sampling should have only one seq.")
|
||||
|
@ -284,16 +304,33 @@ def _greedy_sample(
|
|||
|
||||
|
||||
def _random_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
random_samples: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
"""Run random sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
random_samples: (num_selected_samples,) A tensor of samples. The
|
||||
length of samples could be smaller than selected_seq_groups if
|
||||
seq_group.do_sample is False.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
# Find the maximum best_of value of the prompt phase requests.
|
||||
random_samples = random_samples.cpu()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
sampling_params = seq_group.sampling_params
|
||||
is_prompt = seq_group.is_prompt
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
|
@ -311,11 +348,20 @@ def _random_sample(
|
|||
|
||||
|
||||
def _beam_search_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
selected_seq_groups: List[SequenceGroupToSample],
|
||||
logprobs: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
"""Run beam sampling on a given samples.
|
||||
|
||||
Args:
|
||||
selected_seq_groups: A list of sequence groups batched.
|
||||
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
|
||||
on selected sample indices.
|
||||
Returns:
|
||||
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||
same as the length of selected_seq_groups. If the corresponding
|
||||
seq_group has do_sample=False, tuple contains ([], [])
|
||||
"""
|
||||
# We sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
|
@ -327,8 +373,13 @@ def _beam_search_sample(
|
|||
# other sampling methods.
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
for seq_group in selected_seq_groups:
|
||||
if not seq_group.do_sample:
|
||||
results.append(([], []))
|
||||
continue
|
||||
|
||||
is_prompt = seq_group.is_prompt
|
||||
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
|
||||
num_parent_seqs = len(seq_ids)
|
||||
beam_width = sampling_params.best_of
|
||||
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||
|
@ -343,7 +394,8 @@ def _beam_search_sample(
|
|||
else:
|
||||
# Generation phase.
|
||||
cumulative_logprobs = [
|
||||
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
||||
seq_group.seq_data[seq_id].cumulative_logprob
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
cumulative_logprobs = torch.tensor(
|
||||
cumulative_logprobs,
|
||||
|
@ -371,8 +423,7 @@ def _beam_search_sample(
|
|||
def _multinomial(
|
||||
probs: torch.Tensor,
|
||||
num_samples: int,
|
||||
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
|
||||
generators: Optional[List[torch.Generator]] = None,
|
||||
seq_groups: Optional[List[SequenceGroupToSample]] = None,
|
||||
) -> torch.Tensor:
|
||||
if num_samples > 1:
|
||||
# This is equivalent to torch.repeat_interleaved (which also
|
||||
|
@ -388,9 +439,11 @@ def _multinomial(
|
|||
q.exponential_()
|
||||
else:
|
||||
sample_idx = 0
|
||||
for (seq_ids, _), generator in zip(seq_groups, generators):
|
||||
for seq_group in seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_sample_idx = sample_idx + len(seq_ids) * num_samples
|
||||
q[sample_idx:next_sample_idx].exponential_(generator=generator)
|
||||
q[sample_idx:next_sample_idx].exponential_(
|
||||
generator=seq_group.generator)
|
||||
sample_idx = next_sample_idx
|
||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||
|
||||
|
@ -405,7 +458,7 @@ def _sample_with_torch(
|
|||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
_, sampling_params = seq_group
|
||||
sampling_params = seq_group.sampling_params
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
|
||||
|
@ -429,13 +482,11 @@ def _sample_with_torch(
|
|||
num_tokens = len(sample_indices)
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
||||
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
||||
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
||||
is_prompts, sample_indices)
|
||||
long_sample_indices = sample_indices.long()
|
||||
|
||||
seq_group_id = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
|
||||
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
|
||||
long_sample_indices = sample_indices.long()
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
greedy_samples = torch.argmax(logprobs[long_sample_indices],
|
||||
dim=-1)
|
||||
|
@ -455,14 +506,13 @@ def _sample_with_torch(
|
|||
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
max_best_of_in_batch = 1
|
||||
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
||||
if is_prompt:
|
||||
_, sampling_params = seq_group
|
||||
for seq_group in seq_groups:
|
||||
if seq_group.is_prompt:
|
||||
sampling_params = seq_group.sampling_params
|
||||
max_best_of_in_batch = max(max_best_of_in_batch,
|
||||
sampling_params.best_of)
|
||||
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
|
||||
"seq_groups": seq_groups,
|
||||
"generators": sampling_metadata.generators,
|
||||
}
|
||||
|
||||
multinomial_samples[sampling_type] = _multinomial(
|
||||
|
@ -481,25 +531,22 @@ def _sample_with_torch(
|
|||
|
||||
# GPU<->CPU sync happens in the loop below.
|
||||
# This also converts the sample output to Python objects.
|
||||
|
||||
for sampling_type in SamplingType:
|
||||
if sampling_type not in sample_metadata:
|
||||
continue
|
||||
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
|
||||
sampling_type]
|
||||
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
sample_results = _random_sample(seq_groups, is_prompts,
|
||||
sample_results = _random_sample(seq_groups,
|
||||
multinomial_samples[sampling_type])
|
||||
elif sampling_type == SamplingType.BEAM:
|
||||
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||
sampling_metadata.seq_data,
|
||||
sample_results = _beam_search_sample(seq_groups,
|
||||
beam_search_logprobs)
|
||||
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
||||
sample_results_dict.update(zip(seq_group_id, sample_results))
|
||||
|
||||
sample_results = [
|
||||
sample_results_dict[i]
|
||||
sample_results_dict.get(i, ([], []))
|
||||
for i in range(len(sampling_metadata.seq_groups))
|
||||
]
|
||||
return sample_results, sampled_token_ids_tensor
|
||||
|
@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
|
|||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices = sampling_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
_, sampling_params = seq_group
|
||||
sampling_params = seq_group.sampling_params
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
|
||||
|
@ -530,17 +577,16 @@ def _sample_with_triton_kernel(
|
|||
num_tokens = len(sample_indices)
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
||||
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
||||
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
||||
is_prompts, sample_indices,
|
||||
seq_group_id = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
|
||||
sample_metadata[sampling_type] = (seq_group_id, seq_groups,
|
||||
sample_indices,
|
||||
sampled_token_indices)
|
||||
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
|
||||
SamplingType.RANDOM_SEED):
|
||||
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
||||
if is_prompt:
|
||||
_, sampling_params = seq_group
|
||||
for seq_group in seq_groups:
|
||||
if seq_group.is_prompt:
|
||||
sampling_params = seq_group.sampling_params
|
||||
max_best_of_in_batch = max(max_best_of_in_batch,
|
||||
sampling_params.best_of)
|
||||
elif sampling_type == SamplingType.BEAM:
|
||||
|
@ -564,22 +610,21 @@ def _sample_with_triton_kernel(
|
|||
for sampling_type in SamplingType:
|
||||
if sampling_type not in sample_metadata:
|
||||
continue
|
||||
(seq_group_ids, seq_groups, is_prompts, sample_indices,
|
||||
(seq_group_id, seq_groups, sample_indices,
|
||||
sampled_token_indices) = sample_metadata[sampling_type]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(
|
||||
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
|
||||
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
|
||||
sample_results = _random_sample(
|
||||
seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
|
||||
seq_groups, sampled_tokens[sampled_token_indices])
|
||||
elif sampling_type == SamplingType.BEAM:
|
||||
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||
sampling_metadata.seq_data,
|
||||
sample_results = _beam_search_sample(seq_groups,
|
||||
beam_search_logprobs)
|
||||
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
||||
sample_results_dict.update(zip(seq_group_id, sample_results))
|
||||
|
||||
sample_results = [
|
||||
sample_results_dict[i]
|
||||
sample_results_dict.get(i, ([], []))
|
||||
for i in range(len(sampling_metadata.seq_groups))
|
||||
]
|
||||
return sample_results
|
||||
|
@ -590,6 +635,18 @@ def _sample(
|
|||
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
|
||||
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
|
||||
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
probs: (num_query_tokens_in_batch, num_vocab)
|
||||
logprobs: (num_query_tokens_in_batch, num_vocab)
|
||||
sampling_metadata: The metadata for a batch for sampling.
|
||||
sampling_tensors: Tensors that include sampling related metadata.
|
||||
|
||||
Returns:
|
||||
(next_token_ids, parent_seq_ids) for each seq group in a batch.
|
||||
If sampling is skipped, it returns ([], [])
|
||||
sampled_token_ids_tensor: A tensor of sampled token ids.
|
||||
"""
|
||||
return _sample_with_torch(
|
||||
probs,
|
||||
logprobs,
|
||||
|
@ -626,56 +683,97 @@ def _get_logprobs(
|
|||
logprobs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
sample_results: List[Tuple[List[int], List[int]]],
|
||||
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
|
||||
int, float]]]]:
|
||||
# Prepare query indices
|
||||
batched_logprobs_query_seq_indices: List[int] = []
|
||||
batched_logprobs_query_token_indices: List[int] = []
|
||||
# at least get one logprob for each token
|
||||
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
|
||||
"""Return sample lobprobs and prompt logprobs.
|
||||
|
||||
The logic consists of 3 parts.
|
||||
- Select indices to compute logprob from, ranks of token ids, and
|
||||
the top k token ids from logprobs.
|
||||
- Compute prompt logprobs if required.
|
||||
- Compute sample logprobs if required.
|
||||
|
||||
Args:
|
||||
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
|
||||
logprob per vocab. Sequence groups' query tokens are batched in a
|
||||
single flattened tensor. For example, assuming there are N
|
||||
seq groups, it is sorted by prefill tokens for seq_group_1 (if
|
||||
prompt logprob is enabled), decode tokens for seq_group_1 (if
|
||||
sampling is required), prefill tokens for seq_group_2, ...
|
||||
sampling_metadata: The sampling metadata.
|
||||
sample_results: (num_seq_groups) The tuple of (next_token_ids,
|
||||
parent_ids) for each sequence group. When beam search is enabled,
|
||||
sample_results can contain different number of seq_ids from
|
||||
sampling_metadata.seq_groups. It is because beam search creates
|
||||
2 * BEAM_WIDTH number of samples (whereas there are only up to
|
||||
BEAM_WIDTH number of seq_ids).
|
||||
|
||||
Returns:
|
||||
A tuple of prompt and sample logprobs per sequence group in a batch.
|
||||
"""
|
||||
# The index of query token to calculate logprobs. It includes both
|
||||
# prompt and sample logprob indices.
|
||||
query_indices: List[int] = []
|
||||
# The next token ids to get the logprob value from.
|
||||
next_token_ids: List[int] = []
|
||||
# The largest requested number of logprobs. We find logprobs as many as the
|
||||
# largest num logprobs in this API.
|
||||
largest_num_logprobs = 1
|
||||
sample_idx = 0
|
||||
for i, (seq_group, sample_result) in enumerate(
|
||||
zip(sampling_metadata.seq_groups, sample_results)):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if (i < sampling_metadata.num_prompts
|
||||
|
||||
# Select indices to compute logprob from, ranks of token ids, and the top
|
||||
# k token ids from logprobs.
|
||||
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
|
||||
sample_results):
|
||||
sampling_params = seq_group.sampling_params
|
||||
|
||||
# Update indices and tokens for prompt logprobs.
|
||||
if (seq_group.is_prompt
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.prompt_logprobs)
|
||||
prompt_len = sampling_metadata.prompt_lens[i]
|
||||
prompt_tokens = sampling_metadata.seq_data[
|
||||
seq_ids[0]].prompt_token_ids
|
||||
batched_logprobs_query_seq_indices.extend(
|
||||
sample_idx + j for j in range(prompt_len - 1))
|
||||
batched_logprobs_query_token_indices.extend(
|
||||
token_id for token_id in prompt_tokens[1:])
|
||||
sample_idx += prompt_len - 1
|
||||
batched_logprobs_query_seq_indices.extend(
|
||||
[sample_idx + parent_id for parent_id in parent_ids])
|
||||
batched_logprobs_query_token_indices.extend(next_token_ids)
|
||||
if sampling_params.logprobs is not None:
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.logprobs)
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
|
||||
query_indices.extend(seq_group.prompt_logprob_indices)
|
||||
next_token_ids.extend(next_prompt_tokens)
|
||||
|
||||
batched_logprobs_query_seq_indices_gpu = torch.tensor(
|
||||
batched_logprobs_query_seq_indices, device=logprobs.device)
|
||||
batched_logprobs_query_token_indices_gpu = torch.tensor(
|
||||
batched_logprobs_query_token_indices, device=logprobs.device)
|
||||
# Update indices and next tokenes for sample logprob.
|
||||
if seq_group.do_sample:
|
||||
token_ids, parent_seq_ids = sample_result
|
||||
# NOTE: We cannot directly use sample_indices because
|
||||
# sample_indices only contain parent seq_ids of a previous step.
|
||||
# The current step may have different number of seq_ids, and
|
||||
# we can obtain it from `sample_result[1]`.
|
||||
query_idx = seq_group.sample_indices[0]
|
||||
query_indices.extend(
|
||||
[query_idx + parent_id for parent_id in parent_seq_ids])
|
||||
next_token_ids.extend(token_ids)
|
||||
|
||||
# Batched query for logprobs of selected token
|
||||
batched_logprobs_query_result = logprobs[[
|
||||
batched_logprobs_query_seq_indices_gpu,
|
||||
batched_logprobs_query_token_indices_gpu
|
||||
if sampling_params.logprobs is not None:
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.logprobs)
|
||||
|
||||
assert len(next_token_ids) == len(query_indices)
|
||||
|
||||
if len(query_indices) == 0:
|
||||
empty_sampled_logprob = []
|
||||
empty_prompt_logprob = None
|
||||
return [empty_prompt_logprob], [empty_sampled_logprob]
|
||||
|
||||
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
|
||||
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
|
||||
|
||||
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
|
||||
# contain duplicates if beam search is enabled.
|
||||
selected_logprobs = logprobs[[
|
||||
query_indices_gpu,
|
||||
next_token_ids_gpu,
|
||||
]]
|
||||
ranks = _get_ranks(
|
||||
logprobs[query_indices_gpu],
|
||||
next_token_ids_gpu,
|
||||
)
|
||||
assert selected_logprobs.shape[0] == ranks.shape[0]
|
||||
|
||||
batched_ranks_query_result = _get_ranks(
|
||||
logprobs[batched_logprobs_query_seq_indices_gpu],
|
||||
batched_logprobs_query_token_indices_gpu)
|
||||
|
||||
# Batched query for logprobs of topk tokens
|
||||
# Logprobs of topk tokens for a batch of sequence groups.
|
||||
# (num_query_tokens_across_batch).
|
||||
if largest_num_logprobs > 0:
|
||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||
largest_num_logprobs,
|
||||
|
@ -685,79 +783,136 @@ def _get_logprobs(
|
|||
else:
|
||||
top_logprobs, top_token_ids = None, None
|
||||
|
||||
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
|
||||
batched_ranks_query_result = batched_ranks_query_result.cpu()
|
||||
selected_logprobs = selected_logprobs.cpu()
|
||||
ranks = ranks.cpu()
|
||||
|
||||
# Gather results
|
||||
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
||||
result_sample_logprobs: List[SampleLogprobs] = []
|
||||
sample_idx = 0
|
||||
query_result_idx = 0
|
||||
for i, (seq_group, sample_result) in enumerate(
|
||||
zip(sampling_metadata.seq_groups, sample_results)):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
# Find prompt/sample logprobs.
|
||||
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
|
||||
sample_logprobs_per_seq_group: List[SampleLogprobs] = []
|
||||
top_logprob_idx = 0
|
||||
selected_logprobs_idx = 0
|
||||
|
||||
# Prompt logprobs
|
||||
if (i < sampling_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
num_logprobs = sampling_params.prompt_logprobs
|
||||
prompt_tokens = sampling_metadata.seq_data[
|
||||
seq_ids[0]].prompt_token_ids
|
||||
group_prompt_logprobs: PromptLogprobs = [None]
|
||||
for token_id in prompt_tokens[1:]:
|
||||
prompt_logprobs_dict = {
|
||||
token_id:
|
||||
(batched_logprobs_query_result[query_result_idx].item(),
|
||||
batched_ranks_query_result[query_result_idx].item())
|
||||
}
|
||||
if num_logprobs > 0:
|
||||
prompt_logprobs_dict.update(
|
||||
zip(
|
||||
top_token_ids[sample_idx, :num_logprobs].tolist(),
|
||||
zip(
|
||||
top_logprobs[
|
||||
sample_idx, :num_logprobs].tolist(),
|
||||
range(1, num_logprobs + 1))))
|
||||
group_prompt_logprobs.append({
|
||||
token_id: Logprob(*logprob_rank)
|
||||
for token_id, logprob_rank in prompt_logprobs_dict.items()
|
||||
})
|
||||
sample_idx += 1
|
||||
query_result_idx += 1
|
||||
result_prompt_logprobs.append(group_prompt_logprobs)
|
||||
else:
|
||||
result_prompt_logprobs.append(None)
|
||||
for seq_group, sample_result in zip(sampling_metadata.seq_groups,
|
||||
sample_results):
|
||||
(prompt_logprobs, top_logprob_idx,
|
||||
selected_logprobs_idx) = _get_prompt_logprob_if_needed(
|
||||
seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
|
||||
selected_logprobs_idx, top_logprob_idx)
|
||||
prompt_logprobs_per_seq_group.append(prompt_logprobs)
|
||||
|
||||
# Sample logprobs
|
||||
num_logprobs = sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = 0
|
||||
group_sample_logprobs: SampleLogprobs = []
|
||||
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
|
||||
sample_logprobs_dict = {
|
||||
next_token_id:
|
||||
(batched_logprobs_query_result[query_result_idx].item(),
|
||||
batched_ranks_query_result[query_result_idx].item())
|
||||
(sampled_logprobs, top_logprob_idx,
|
||||
selected_logprobs_idx) = _get_sampled_logprob_if_needed(
|
||||
seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
|
||||
top_logprobs, selected_logprobs_idx, top_logprob_idx)
|
||||
sample_logprobs_per_seq_group.append(sampled_logprobs)
|
||||
|
||||
return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
|
||||
|
||||
|
||||
def _get_prompt_logprob_if_needed(
|
||||
seq_group: SequenceGroupToSample,
|
||||
selected_logprobs: torch.Tensor,
|
||||
ranks: torch.Tensor,
|
||||
top_token_ids: torch.Tensor,
|
||||
top_logprobs: torch.Tensor,
|
||||
selected_logprobs_idx: int,
|
||||
top_logprob_idx: int,
|
||||
):
|
||||
"""Compute the prompt logprob from a sequence group if needed."""
|
||||
sampling_params = seq_group.sampling_params
|
||||
is_prompt = seq_group.is_prompt
|
||||
|
||||
# Find prompt logprobs
|
||||
prompt_logprobs: Optional[PromptLogprobs] = None
|
||||
if (is_prompt and sampling_params.prompt_logprobs is not None):
|
||||
prompt_logprobs = []
|
||||
num_logprobs = sampling_params.prompt_logprobs
|
||||
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
|
||||
for token_id in next_prompt_tokens:
|
||||
# Calculate the prompt logprob of the real prompt tokens.
|
||||
# Use tuple here for performance (to use to_list()).
|
||||
# {token_id: (logprob, rank_from_vocab)}
|
||||
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
|
||||
token_id: (selected_logprobs[selected_logprobs_idx].item(),
|
||||
ranks[selected_logprobs_idx].item())
|
||||
}
|
||||
query_result_idx += 1
|
||||
if num_logprobs >= 0:
|
||||
sample_logprobs_dict.update(
|
||||
|
||||
# Add top K prompt logprobs along with its rank.
|
||||
if num_logprobs > 0:
|
||||
prompt_logprobs_dict.update(
|
||||
zip(
|
||||
top_token_ids[sample_idx +
|
||||
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
|
||||
zip(
|
||||
top_logprobs[
|
||||
top_logprob_idx, :num_logprobs].tolist(),
|
||||
# This is ranks. Since top_logprob is sorted,
|
||||
# we can just use a range here.
|
||||
range(1, num_logprobs + 1))))
|
||||
prompt_logprobs.append({
|
||||
token_id: Logprob(*logprob_and_rank)
|
||||
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
|
||||
})
|
||||
# + 1 to go to the next prompt token.
|
||||
top_logprob_idx += 1
|
||||
selected_logprobs_idx += 1
|
||||
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
|
||||
|
||||
|
||||
def _get_sampled_logprob_if_needed(
|
||||
seq_group: SequenceGroupToSample,
|
||||
sample_result: Tuple[List[int], List[int]],
|
||||
selected_logprobs: torch.Tensor,
|
||||
ranks: torch.Tensor,
|
||||
top_token_ids: torch.Tensor,
|
||||
top_logprobs: torch.Tensor,
|
||||
selected_logprobs_idx: int,
|
||||
top_logprob_idx: int,
|
||||
):
|
||||
"""Compute the sample logprob if needed."""
|
||||
seq_ids = seq_group.seq_ids
|
||||
num_logprobs = seq_group.sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = 0
|
||||
sampled_logprobs: SampleLogprobs = []
|
||||
next_token_ids, parent_seq_ids = sample_result
|
||||
|
||||
if seq_group.do_sample:
|
||||
assert len(next_token_ids) > 0
|
||||
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
|
||||
# Calculate the sample logprob of the real sampled tokens.
|
||||
# Use tuple here for performance (to use to_list()).
|
||||
# token_id: (logprob, rank_from_vocab)
|
||||
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
|
||||
next_token_id:
|
||||
(selected_logprobs[selected_logprobs_idx].item(),
|
||||
ranks[selected_logprobs_idx].item())
|
||||
}
|
||||
# +1 to go to the next sampled token. Note that
|
||||
# selected_logprobs can contain duplicates unlike top_logprobs
|
||||
# when beam search is enabled.
|
||||
selected_logprobs_idx += 1
|
||||
|
||||
# Second, add top K logprobs along with its rank.
|
||||
if num_logprobs >= 0:
|
||||
sampled_logprobs_dict.update(
|
||||
zip(
|
||||
top_token_ids[top_logprob_idx +
|
||||
parent_id, :num_logprobs].tolist(),
|
||||
zip(
|
||||
top_logprobs[sample_idx +
|
||||
top_logprobs[top_logprob_idx +
|
||||
parent_id, :num_logprobs].tolist(),
|
||||
# This is rank. Since top_logprob is sorted, we
|
||||
# can just use a range here.
|
||||
range(1, num_logprobs + 1))))
|
||||
group_sample_logprobs.append({
|
||||
token_id: Logprob(*logprob_rank)
|
||||
for token_id, logprob_rank in sample_logprobs_dict.items()
|
||||
sampled_logprobs.append({
|
||||
token_id: Logprob(*logprob_and_rank)
|
||||
for token_id, logprob_and_rank in
|
||||
sampled_logprobs_dict.items()
|
||||
})
|
||||
result_sample_logprobs.append(group_sample_logprobs)
|
||||
sample_idx += len(seq_ids)
|
||||
|
||||
return result_prompt_logprobs, result_sample_logprobs
|
||||
# There are len(seq_ids) number of sampled tokens for the current
|
||||
# sequence group in top_logprobs. Jump to the next seq_group.
|
||||
top_logprob_idx += len(seq_ids)
|
||||
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
|
||||
|
||||
|
||||
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
|
||||
|
@ -832,7 +987,7 @@ def _build_sampler_output(
|
|||
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
|
||||
sample_results, prompt_logprobs,
|
||||
sample_logprobs):
|
||||
seq_ids, _ = seq_group
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids, parent_ids = sample_result
|
||||
seq_outputs = []
|
||||
for parent_id, next_token_id, logprobs in zip(parent_ids,
|
||||
|
@ -854,3 +1009,36 @@ def _build_sampler_output(
|
|||
sampled_token_probs=sampled_token_probs,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
|
||||
"""Get a list of next prompt tokens to compute logprob from a
|
||||
given sequence group.
|
||||
|
||||
It is used to compute prompt logprob. Imagine you have logprob for each
|
||||
query token. Query token needs to know the next prompt token id to compute
|
||||
prompt logprob. This is a helper to obtain next prompt token ids.
|
||||
|
||||
This API has to be used only when the caller knows seq_group is in prefill
|
||||
stage.
|
||||
|
||||
Returns:
|
||||
A list of next prompt tokens to compute logprob.
|
||||
"""
|
||||
assert seq_group.is_prompt, (
|
||||
"Caller should ensure the sequence group is in a prefill stage.")
|
||||
seq_ids = seq_group.seq_ids
|
||||
subquery_len = seq_group.subquery_len
|
||||
assert subquery_len is not None
|
||||
# prompt has only 1 seq id.
|
||||
assert len(seq_ids) == 1
|
||||
seq_data = seq_group.seq_data[seq_ids[0]]
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
prompt_tokens = seq_data.prompt_token_ids
|
||||
# +1 because we are looking for a next prompt token.
|
||||
next_token_index_start = computed_len + 1
|
||||
next_token_index_end = min(computed_len + subquery_len + 1,
|
||||
len(prompt_tokens))
|
||||
next_prompt_tokens = prompt_tokens[
|
||||
next_token_index_start:next_token_index_end]
|
||||
return next_prompt_tokens
|
||||
|
|
|
@ -6,57 +6,275 @@ import torch
|
|||
|
||||
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.sequence import SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
|
||||
maybe_expand_dim)
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
_SEED_0_REPLACEMENT = 3403598558
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceGroupToSample:
|
||||
# Sequence ids for the sequence group in a previous step.
|
||||
seq_ids: List[int]
|
||||
sampling_params: SamplingParams
|
||||
# seq_id -> sequence data.
|
||||
seq_data: Dict[int, SequenceData]
|
||||
# The length of the prompt of the sequence group. None if it is in a decode
|
||||
# stage.
|
||||
prompt_len: Optional[int]
|
||||
# The length of the query tokens to compute in the current step. None if it
|
||||
# is in a decode stage. The length of subquery_len <= prompt_len.
|
||||
subquery_len: Optional[int]
|
||||
# A random number generator for sampling.
|
||||
generator: Optional[torch.Generator]
|
||||
# True if the sequence group is in prefill stage. False if it is in a
|
||||
# decode stage.
|
||||
is_prompt: bool
|
||||
# Query token indices from logits. to compute prompt logprob. Empty if
|
||||
# prompt logprob is not required.
|
||||
prompt_logprob_indices: List[int]
|
||||
# Sample token indices from logits. Empty if sampling is not required.
|
||||
sample_indices: List[int]
|
||||
|
||||
@property
|
||||
def do_sample(self):
|
||||
return len(self.sample_indices) > 0
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.prompt_logprob_indices) > 0:
|
||||
assert self.sampling_params.prompt_logprobs is not None
|
||||
if self.is_prompt:
|
||||
assert self.prompt_len is not None
|
||||
assert self.subquery_len is not None
|
||||
|
||||
|
||||
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.

The usage is as follows:
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)

def sample(logits):
# Use categorized_sample_indices for sampling....
```

Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
seq_groups: List of batched sequence groups.
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
logits from the initial model output hidden states.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
Each token indices is 2D tensor of (num_indices, num_indices) where
the first item means the sample index within the returned logit
(before pruning padding), and the second item means the sample
index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
"""

def __init__(
self,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
seq_data: Optional[Dict[int, SequenceData]],
prompt_lens: Optional[List[int]],
seq_groups: List[SequenceGroupToSample],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
generators: Optional[List[torch.Generator]] = None,
perform_sampling: bool = True,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.generators = generators
self.perform_sampling = perform_sampling
self.num_prompts = num_prompts

self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
(
seq_groups,
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
subquery_lens, device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}

sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
num_prompts=num_prompts,
)
return sampling_metadata

def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices}), "
f"perform_sampling={self.perform_sampling})")
f"categorized_sample_indices={self.categorized_sample_indices}), ")

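A pure-Python sketch (not from the diff, toy values only) of the pruning described in the class docstring: the model produces one hidden state per input token, and `selected_token_indices` keeps only the positions that need a logprob or a sample.

```
hidden_states = ["h0", "h1", "h2", "h3", "h4", "h5"]  # one entry per input token
selected_token_indices = [1, 2, 5]                    # positions we care about

logits = [hidden_states[i] for i in selected_token_indices]
print(logits)  # ['h1', 'h2', 'h5']

# Within the pruned `logits`, prompt_logprob_indices / sample_indices are
# expressed relative to this smaller list, e.g. index 2 here refers to "h5".
```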
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.

Args:
seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.

Returns:
seq_groups: A list of sequence group to sample.
selected_token_indices: See the definition from `SamplingMetadata`.
categorized_sample_indices: See the definition from `SamplingMetadata`.
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward step.
seq_groups: List[SequenceGroupToSample] = []
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices: List[int] = []
# Used for selected_token_indices.
model_output_idx = 0

# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
prompt_len: Optional[int] = None
subquery_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample

if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=device).manual_seed(sampling_params.seed)

num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert subquery_lens is not None and prompt_lens is not None
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (subquery_len - num_prefill_sample
if do_sample else subquery_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0

# Update indices to select from the model output.
"""
This block computes selected_token_indices which is used in the
following way.

hidden_states = model(...)
logits = hidden_states[selected_token_indices]
"""

if sampling_params.prompt_logprobs:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
if do_sample:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + sample_len))
model_output_idx += sample_len

# We now find indices for logprob computation and sampling.
"""
This block computes categorized_sample_indices which is used in the
following way.

hidden_states = model(...)
logits = hidden_states[selected_token_indices]
def sample(logits):
# Use categorized_sample_indices for sampling.
# prompt_logprob_indices to find prompt logprob indices.
# sample_indices to find sample indices.
"""

if sampling_params.prompt_logprobs is not None:
prompt_logprob_indices.extend(
range(logit_idx, logit_idx + prompt_logprob_len))
logit_idx += prompt_logprob_len
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
logit_idx += sample_len
sample_idx += sample_len

if sampling_params.seed is not None:
generator = seq_group_metadata.state.generator

seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
prompt_len=prompt_len,
subquery_len=subquery_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts)

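A simplified, self-contained re-enactment (not the real function) of the bookkeeping in `_prepare_seq_groups` for one toy batch: a chunked prefill that only wants prompt logprobs, followed by a single decoding sequence. The guards on `sampling_params` are folded into precomputed lengths, and all numbers are invented for the example.

```
def toy_prepare(groups):
    selected_token_indices, categorized = [], []
    model_output_idx = logit_idx = sample_idx = 0
    out = []
    for g in groups:
        prompt_logprob_len = g["prompt_logprob_len"]
        sample_len = g["sample_len"]
        # Indices into the raw model output (one row per input token).
        if g["wants_prompt_logprobs"]:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
            model_output_idx += prompt_logprob_len
        if sample_len:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
            model_output_idx += sample_len
        # Indices into the pruned logits.
        prompt_logprob_indices = list(
            range(logit_idx, logit_idx + prompt_logprob_len))
        logit_idx += prompt_logprob_len
        sample_indices = list(range(logit_idx, logit_idx + sample_len))
        categorized.extend(zip(range(logit_idx, logit_idx + sample_len),
                               range(sample_idx, sample_idx + sample_len)))
        logit_idx += sample_len
        sample_idx += sample_len
        out.append((prompt_logprob_indices, sample_indices))
    return selected_token_indices, categorized, out

groups = [
    # chunked prefill: 4 query tokens, prompt logprobs only, no sampling yet
    {"prompt_logprob_len": 4, "sample_len": 0, "wants_prompt_logprobs": True},
    # decode: one sequence, one sampled token
    {"prompt_logprob_len": 0, "sample_len": 1, "wants_prompt_logprobs": False},
]
print(toy_prepare(groups))
# ([0, 1, 2, 3, 4], [(4, 0)], [([0, 1, 2, 3], []), ([], [4])])
```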
@dataclass

@ -112,11 +330,10 @@ class SamplingTensors:
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))

sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
temperature = sampling_params.temperature
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty

@ -145,45 +362,46 @@ class SamplingTensors:
or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True

if (i < sampling_metadata.num_prompts
is_prompt = seq_group.is_prompt
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
subquery_len = seq_group.subquery_len
assert subquery_len is not None
prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len
top_ks += [top_k] * prefill_len
min_ps += [min_p] * prefill_len
presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))

if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)

is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
subquery_len = seq_group.subquery_len
assert subquery_len is not None

if sampling_params.prompt_logprobs is not None:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,

@ -193,8 +411,7 @@ class SamplingTensors:
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.append(sample_indices_start_idx)
sample_indices_start_idx += 1
sample_indices.extend(seq_group.sample_indices)

sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,

@ -217,12 +434,14 @@ class SamplingTensors:
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max(len(tokens) for tokens in output_tokens)
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens

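A short sketch of why the diff switches to `max(..., default=0)`: with chunked prefill a batch can plausibly reach this point with an empty `prompt_tokens` list (for example, a logprob-only prefill chunk with nothing to sample), and plain `max()` would raise `ValueError` on an empty sequence. The values below are illustrative only.

```
prompt_tokens = []  # e.g. a batch with no penalty/sampling rows this step
vocab_size = 32000

prompt_max_len = max([len(tokens) for tokens in prompt_tokens], default=0)
prompt_padded_tokens = [
    tokens + [vocab_size] * (prompt_max_len - len(tokens))
    for tokens in prompt_tokens
]
print(prompt_max_len, prompt_padded_tokens)  # 0 []
```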
@ -28,7 +28,10 @@ class Logprob:
decoded_token: Optional[str] = None


# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]

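An illustrative sketch of the shape behind the `PromptLogprobs` alias, using plain dicts in place of the `Logprob` dataclass and made-up ids/values. There is one entry per prompt token position, and the first position has no preceding context, so it is `None`.

```
prompt_logprobs = [
    None,                         # position 0: no prompt logprob
    {464: -2.31, 1212: -3.05},    # position 1: top candidates incl. the actual token
    {318: -0.87},                 # position 2
]
for pos, entry in enumerate(prompt_logprobs):
    if entry is not None:
        top_token = max(entry, key=entry.get)  # highest-logprob candidate
        print(pos, top_token, entry[top_token])
```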
@ -215,7 +218,7 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request

self.data = SequenceData(prompt_token_ids)
self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

@ -559,6 +562,9 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
do_sample: True if sampling is required. Sampling is not required when,
e.g., prefill is chunked and the current iteration only computes
query tokens of the prompt; in that case we don't need to sample.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.

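A hedged sketch of how a scheduler-side caller might set `do_sample` per chunk when a 9-token prompt is prefilled in chunks of 4; the loop and names are illustrative, not the actual vLLM scheduler API.

```
prompt_len, chunk = 9, 4
computed = 0
while computed < prompt_len:
    token_chunk_size = min(chunk, prompt_len - computed)
    is_last_chunk = computed + token_chunk_size == prompt_len
    do_sample = is_last_chunk  # only the final chunk produces a sampled token
    print(computed, token_chunk_size, do_sample)
    computed += token_chunk_size
# 0 4 False / 4 4 False / 8 1 True
```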
@ -573,6 +579,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,

@ -589,6 +596,7 @@ class SequenceGroupMetadata:
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample

if self._token_chunk_size is None:
if is_prompt:

@ -650,6 +658,7 @@ class SequenceGroupOutput:
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs

def __repr__(self) -> str:

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import torch
from torch import nn

@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

logger = init_logger(__name__)

@ -38,6 +37,8 @@ class CPUModelRunner:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config

@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata,
)

def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))

if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
subquery_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1

categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1

if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len

if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs

categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs

if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)

selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long)

categorized_sample_indices = {
t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}

seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)

sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata

def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

@ -364,8 +272,15 @@ class CPUModelRunner:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,

@ -389,7 +304,6 @@ class CPUModelRunner:
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
)

return (input_tokens, input_positions, attn_metadata,

@ -421,7 +335,7 @@ class CPUModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)

# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None

# Sample the next token.

@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
make_tensor_with_pad)

logger = init_logger(__name__)

@ -547,108 +546,6 @@ class ModelRunner:
slot_mapping=slot_mapping,
)

def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))

if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1

categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1

if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len

if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs

categorized_sample_indices[
sampling_params.sampling_type].extend(
list(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs

if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)

selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)

categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}

seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)

sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata

def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

@ -685,9 +582,9 @@ class ModelRunner:
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens,
self.device, self.pin_memory)

if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0

@ -788,12 +685,9 @@ class ModelRunner:
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
num_prompts=0,
)

# if it is a mixed batch, decode attn_metadata is broadcasted

@ -852,7 +746,7 @@ class ModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)

# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None

# Sample the next token.

@ -860,6 +754,7 @@ class ModelRunner:
logits=logits,
sampling_metadata=sampling_metadata,
)

return output

@torch.inference_mode()

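A toy re-enactment (not the vLLM API) of the control-flow change above: the driver/non-driver decision now lives on the runner instead of on `SamplingMetadata.perform_sampling`, so non-driver workers simply return `None` after computing logits.

```
class ToyRunner:
    def __init__(self, is_driver_worker: bool):
        self.is_driver_worker = is_driver_worker

    def execute_model(self, hidden_states, sampling_metadata):
        logits = [h * 2 for h in hidden_states]          # stand-in for compute_logits
        if not self.is_driver_worker:
            return None                                  # workers skip sampling entirely
        return max(range(len(logits)), key=logits.__getitem__)  # stand-in for sample

print(ToyRunner(True).execute_model([0.1, 0.9, 0.3], None))   # 1
print(ToyRunner(False).execute_model([0.1, 0.9, 0.3], None))  # None
```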
@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import torch
from torch import nn

@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad

logger = init_logger(__name__)

@ -141,106 +139,6 @@ class NeuronModelRunner:

return input_tokens, input_positions, input_block_ids

def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))

if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1

categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1

if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len

if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs

categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs

if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)

selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)

categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}

seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)

sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata

def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

@ -256,8 +154,15 @@ class NeuronModelRunner:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
self.device,
self.pin_memory)

return (input_tokens, input_positions, input_block_ids,
sampling_metadata)