[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)

This commit is contained in:
SangBin Cho 2024-04-26 22:02:02 +09:00 committed by GitHub
parent a88081bf76
commit 603ad84815
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 859 additions and 630 deletions

View File

@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
example_prompts,
):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5
num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
vllm_model = vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=5,
prompt_logprobs=num_top_logprobs,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
for prompt_logprobs in result.prompt_logprobs[1:]:
# If the prompt token is not included in the top X
# logprob, it can return 1 more data
assert (len(prompt_logprobs) == num_top_logprobs
or len(prompt_logprobs) == num_top_logprobs + 1)
# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned "
" to the user.")
# Test if prompt logprobs are correctly set.
for vllm_result in vllm_results:
token_ids = vllm_result.prompt_token_ids
prompt_logprobs = vllm_result.prompt_logprobs
# The first token doesn't have logprob.
assert prompt_logprobs[0] is None
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
assert token_id in logprob_dict
def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)

View File

@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter
@ -54,6 +55,7 @@ def _do_sample(
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
device: str,
):
seq_group_metadata_list = []
prompt_lens = []
@ -68,9 +70,12 @@ def _do_sample(
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
model_runner, sampling_params, device)
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
model_runner, sampling_params)
model_runner, sampling_params, device)
assert first_sampler_output == second_sampler_output
@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens if prompt_lens else None)
subquery_lens=prompt_lens if prompt_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=device,
pin_memory=model_runner.pin_memory)
sample_probs = None

View File

@ -6,6 +6,7 @@ import pytest
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,

View File

@ -2,6 +2,7 @@ import pytest
import torch
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,

View File

@ -915,6 +915,20 @@ class Scheduler:
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True
if seq_group.is_prefill():
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
@ -924,6 +938,7 @@ class Scheduler:
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,

View File

@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine):
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
if self.log_stats:

View File

@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceStage)
SequenceGroup, SequenceGroupMetadata)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@ -476,9 +476,12 @@ class LLMEngine:
return self.scheduler.has_unfinished_seqs()
def _process_model_outputs(
self, output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
self,
output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
@ -492,17 +495,15 @@ class LLMEngine:
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
output_by_sequence_group):
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
@ -585,7 +586,7 @@ class LLMEngine:
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
if self.log_stats:

View File

@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
scheduler.
"""
pass
@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Update prompt logprobs received from outputs to seq_group."""
pass

View File

@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
logger.warning(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers).")
pass
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.

View File

@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None and \
seq_group.sampling_params.detokenize and self.detokenizer:
def process_prompt_logprob(self, seq_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
assert len(outputs) == 1, ("Single step should only has 1 output.")
output = outputs[0]
prompt_logprobs = output.prompt_logprobs
if (prompt_logprobs is not None
and seq_group.sampling_params.detokenize and self.detokenizer):
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs
if not seq_group.prompt_logprobs:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group.prompt_logprobs = [None]
seq_group.prompt_logprobs.extend(prompt_logprobs)
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)

View File

@ -1,10 +1,11 @@
from typing import List
from vllm.sequence import SamplerOutput
from vllm.sequence import SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
num_seq_groups: int):
def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""

View File

@ -83,30 +83,27 @@ def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_processed = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0]
assert logits_processed == logits.shape[0]
return logits

View File

@ -7,11 +7,11 @@ import torch.nn as nn
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
SamplerOutput, SequenceGroupOutput, SequenceOutput)
class Sampler(nn.Module):
@ -48,11 +48,14 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
@ -83,7 +86,6 @@ class Sampler(nn.Module):
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
start_idx += sampling_metadata.prompt_lens[i] - 1
sample_indices = seq_group.sample_indices
logits_applied += len(sample_indices) + len(
seq_group.prompt_logprob_indices)
if not seq_group.do_sample:
continue
start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id]
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
start_idx += len(seq_ids)
if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly
assert start_idx == logits.shape[0]
assert logits_applied == logits.shape[0]
return logits
@ -265,14 +269,30 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples = samples.tolist()
sample_idx = 0
results = []
for seq_group in selected_seq_groups:
seq_ids, _ = seq_group
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
@ -284,16 +304,33 @@ def _greedy_sample(
def _random_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
@ -311,11 +348,20 @@ def _random_sample(
def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
@ -327,8 +373,13 @@ def _beam_search_sample(
# other sampling methods.
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
@ -343,7 +394,8 @@ def _beam_search_sample(
else:
# Generation phase.
cumulative_logprobs = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs,
@ -371,8 +423,7 @@ def _beam_search_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
generators: Optional[List[torch.Generator]] = None,
seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
@ -388,9 +439,11 @@ def _multinomial(
q.exponential_()
else:
sample_idx = 0
for (seq_ids, _), generator in zip(seq_groups, generators):
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(generator=generator)
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
@ -405,7 +458,7 @@ def _sample_with_torch(
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
@ -429,13 +482,11 @@ def _sample_with_torch(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
@ -455,14 +506,13 @@ def _sample_with_torch(
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
@ -481,25 +531,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
sampling_type]
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups, is_prompts,
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict[i]
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results, sampled_token_ids_tensor
@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
@ -530,17 +577,16 @@ def _sample_with_triton_kernel(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices,
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups,
sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
@ -564,22 +610,21 @@ def _sample_with_triton_kernel(
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_ids, seq_groups, is_prompts, sample_indices,
(seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict[i]
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
@ -590,6 +635,18 @@ def _sample(
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
logprobs,
@ -626,56 +683,97 @@ def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:
# Prepare query indices
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
# at least get one logprob for each token
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and
the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs = 1
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < sampling_metadata.num_prompts
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
sample_results):
sampling_params = seq_group.sampling_params
# Update indices and tokens for prompt logprobs.
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1))
batched_logprobs_query_token_indices.extend(
token_id for token_id in prompt_tokens[1:])
sample_idx += prompt_len - 1
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
query_indices.extend(seq_group.prompt_logprob_indices)
next_token_ids.extend(next_prompt_tokens)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Update indices and next tokenes for sample logprob.
if seq_group.do_sample:
token_ids, parent_seq_ids = sample_result
# NOTE: We cannot directly use sample_indices because
# sample_indices only contain parent seq_ids of a previous step.
# The current step may have different number of seq_ids, and
# we can obtain it from `sample_result[1]`.
query_idx = seq_group.sample_indices[0]
query_indices.extend(
[query_idx + parent_id for parent_id in parent_seq_ids])
next_token_ids.extend(token_ids)
# Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices_gpu
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
empty_sampled_logprob = []
empty_prompt_logprob = None
return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
@ -685,79 +783,136 @@ def _get_logprobs(
else:
top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu()
selected_logprobs = selected_logprobs.cpu()
ranks = ranks.cpu()
# Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
result_sample_logprobs: List[SampleLogprobs] = []
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: List[SampleLogprobs] = []
top_logprob_idx = 0
selected_logprobs_idx = 0
# Prompt logprobs
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:
prompt_logprobs_dict = {
token_id:
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
}
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(
top_token_ids[sample_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
sample_idx, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_prompt_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in prompt_logprobs_dict.items()
})
sample_idx += 1
query_result_idx += 1
result_prompt_logprobs.append(group_prompt_logprobs)
else:
result_prompt_logprobs.append(None)
for seq_group, sample_result in zip(sampling_metadata.seq_groups,
sample_results):
(prompt_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_prompt_logprob_if_needed(
seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
selected_logprobs_idx, top_logprob_idx)
prompt_logprobs_per_seq_group.append(prompt_logprobs)
# Sample logprobs
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
group_sample_logprobs: SampleLogprobs = []
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
sample_logprobs_dict = {
next_token_id:
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
(sampled_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_sampled_logprob_if_needed(
seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
top_logprobs, selected_logprobs_idx, top_logprob_idx)
sample_logprobs_per_seq_group.append(sampled_logprobs)
return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
def _get_prompt_logprob_if_needed(
seq_group: SequenceGroupToSample,
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
top_logprobs: torch.Tensor,
selected_logprobs_idx: int,
top_logprob_idx: int,
):
"""Compute the prompt logprob from a sequence group if needed."""
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
# Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None):
prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens:
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
}
query_result_idx += 1
if num_logprobs >= 0:
sample_logprobs_dict.update(
# Add top K prompt logprobs along with its rank.
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(
top_token_ids[sample_idx +
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
top_logprob_idx, :num_logprobs].tolist(),
# This is ranks. Since top_logprob is sorted,
# we can just use a range here.
range(1, num_logprobs + 1))))
prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
})
# + 1 to go to the next prompt token.
top_logprob_idx += 1
selected_logprobs_idx += 1
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
def _get_sampled_logprob_if_needed(
seq_group: SequenceGroupToSample,
sample_result: Tuple[List[int], List[int]],
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
top_logprobs: torch.Tensor,
selected_logprobs_idx: int,
top_logprob_idx: int,
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
next_token_id:
(selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
}
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx += 1
# Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
sampled_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
top_logprobs[sample_idx +
top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
group_sample_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in sample_logprobs_dict.items()
sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids)
return result_prompt_logprobs, result_sample_logprobs
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
@ -832,7 +987,7 @@ def _build_sampler_output(
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
@ -854,3 +1009,36 @@ def _build_sampler_output(
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
It is used to compute prompt logprob. Imagine you have logprob for each
query token. Query token needs to know the next prompt token id to compute
prompt logprob. This is a helper to obtain next prompt token ids.
This API has to be used only when the caller knows seq_group is in prefill
stage.
Returns:
A list of next prompt tokens to compute logprob.
"""
assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids
subquery_len = seq_group.subquery_len
assert subquery_len is not None
# prompt has only 1 seq id.
assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]]
computed_len = seq_data.get_num_computed_tokens()
prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1,
len(prompt_tokens))
next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end]
return next_prompt_tokens

View File

@ -6,57 +6,275 @@ import torch
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import is_pin_memory_available
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
maybe_expand_dim)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
@dataclass
class SequenceGroupToSample:
# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
# The length of the prompt of the sequence group. None if it is in a decode
# stage.
prompt_len: Optional[int]
# The length of the query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len.
subquery_len: Optional[int]
# A random number generator for sampling.
generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a
# decode stage.
is_prompt: bool
# Query token indices from logits. to compute prompt logprob. Empty if
# prompt logprob is not required.
prompt_logprob_indices: List[int]
# Sample token indices from logits. Empty if sampling is not required.
sample_indices: List[int]
@property
def do_sample(self):
return len(self.sample_indices) > 0
def __post_init__(self):
if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt:
assert self.prompt_len is not None
assert self.subquery_len is not None
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
The usage is as follow;
```
hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)
def sample(logits):
# Use categorized_sample_indices for sampling....
```
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
seq_groups: List of batched sequence groups.
selected_token_indices: (num_query_tokens_to_logprob). Indices to find
logits from the initial model output hidden states.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
Each token indices is 2D tensor of (num_indices, num_indices) where
the first item means the sample index within the returned logit
(before pruning padding), and the second item means the sample
index after pruning using selected_token_indices.
For example, if the returned logit is [1, 2, 3], and we select
[1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
"""
def __init__(
self,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
seq_data: Optional[Dict[int, SequenceData]],
prompt_lens: Optional[List[int]],
seq_groups: List[SequenceGroupToSample],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
generators: Optional[List[torch.Generator]] = None,
perform_sampling: bool = True,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.generators = generators
self.perform_sampling = perform_sampling
self.num_prompts = num_prompts
self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
(
seq_groups,
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
subquery_lens, device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
num_prompts=num_prompts,
)
return sampling_metadata
def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices}), "
f"perform_sampling={self.perform_sampling})")
f"categorized_sample_indices={self.categorized_sample_indices}), ")
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
Args:
seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
Returns:
seq_groups: A list of sequence group to sample.
selected_token_indices: See the definition from `SamplingMetadata`.
categorized_sample_indices: See the definition from `SamplingMetadata`.
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward stsep.
seq_groups: List[SequenceGroupToSample] = []
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices: List[int] = []
# Used for selected_token_indices.
model_output_idx = 0
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
prompt_len: Optional[int] = None
subquery_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=device).manual_seed(sampling_params.seed)
num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert subquery_lens is not None and prompt_lens is not None
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (subquery_len - num_prefill_sample
if do_sample else subquery_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
"""
if sampling_params.prompt_logprobs:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
if do_sample:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + sample_len))
model_output_idx += sample_len
# We now find indices for logprob computation and sampling.
"""
This block computes categorized_sample_indices which is used in the
following way.
hidden_states = model(...)
logits = hidden_states[selected_token_indices]
def sample(logits):
# Use categorized_sample_indices for sampling.
# prompt_logprob_indices to find prompt logprob indices.
# sample_indices to find sample indices.
"""
if sampling_params.prompt_logprobs is not None:
prompt_logprob_indices.extend(
range(logit_idx, logit_idx + prompt_logprob_len))
logit_idx += prompt_logprob_len
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
logit_idx += sample_len
sample_idx += sample_len
if sampling_params.seed is not None:
generator = seq_group_metadata.state.generator
seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
prompt_len=prompt_len,
subquery_len=subquery_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts)
@dataclass
@ -112,11 +330,10 @@ class SamplingTensors:
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
temperature = sampling_params.temperature
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
@ -145,45 +362,46 @@ class SamplingTensors:
or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True
if (i < sampling_metadata.num_prompts
is_prompt = seq_group.is_prompt
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
subquery_len = seq_group.subquery_len
assert subquery_len is not None
prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len
top_ks += [top_k] * prefill_len
min_ps += [min_p] * prefill_len
presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
subquery_len = seq_group.subquery_len
assert subquery_len is not None
if sampling_params.prompt_logprobs is not None:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
@ -193,8 +411,7 @@ class SamplingTensors:
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.append(sample_indices_start_idx)
sample_indices_start_idx += 1
sample_indices.extend(seq_group.sample_indices)
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
@ -217,12 +434,14 @@ class SamplingTensors:
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max(len(tokens) for tokens in output_tokens)
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens

View File

@ -28,7 +28,10 @@ class Logprob:
decoded_token: Optional[str] = None
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
@ -215,7 +218,7 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids)
self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@ -559,6 +562,9 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
do_sample: True if sampling is required. Sampling is not required when
e.g., prefill is chunked, and the current iteration only computes
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
@ -573,6 +579,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
@ -589,6 +596,7 @@ class SequenceGroupMetadata:
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
if self._token_chunk_size is None:
if is_prompt:
@ -650,6 +658,7 @@ class SequenceGroupOutput:
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad, maybe_expand_dim
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
logger = init_logger(__name__)
@ -38,6 +37,8 @@ class CPUModelRunner:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
@ -252,99 +253,6 @@ class CPUModelRunner:
attn_metadata,
)
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
subquery_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long)
categorized_sample_indices = {
t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -364,8 +272,15 @@ class CPUModelRunner:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
@ -389,7 +304,6 @@ class CPUModelRunner:
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
)
return (input_tokens, input_positions, attn_metadata,
@ -421,7 +335,7 @@ class CPUModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None
# Sample the next token.

View File

@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
make_tensor_with_pad)
logger = init_logger(__name__)
@ -547,108 +546,6 @@ class ModelRunner:
slot_mapping=slot_mapping,
)
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
list(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -685,9 +582,9 @@ class ModelRunner:
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens,
self.device, self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
@ -788,12 +685,9 @@ class ModelRunner:
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
num_prompts=0,
)
# if it is a mixed batch, decode attn_metadata is broadcasted
@ -852,7 +746,7 @@ class ModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None
# Sample the next token.
@ -860,6 +754,7 @@ class ModelRunner:
logits=logits,
sampling_metadata=sampling_metadata,
)
return output
@torch.inference_mode()

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
logger = init_logger(__name__)
@ -141,106 +139,6 @@ class NeuronModelRunner:
return input_tokens, input_positions, input_block_ids
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -256,8 +154,15 @@ class NeuronModelRunner:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
self.device,
self.pin_memory)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)