mirror of https://github.com/vllm-project/vllm

[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)

parent a88081bf76
commit 603ad84815
@@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"]

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
+@pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
     model,
     dtype,
+    chunked_prefill_token_size: int,
+    num_top_logprobs: int,
     example_prompts,
 ):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
     max_tokens = 5
-    num_top_logprobs = 6
     hf_model = hf_runner(model, dtype=dtype)
     hf_logprobs = hf_model.generate_greedy_logprobs(
         example_prompts,
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
     )
     del hf_model

-    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        max_logprobs=num_top_logprobs,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs,
+    )
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
-                                          prompt_logprobs=5,
+                                          prompt_logprobs=num_top_logprobs,
                                           temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
             "The output text from the top logprob for each token position "
             "should be the same as the output text in the result.")

+        # The first prompt logprob is always None
+        assert result.prompt_logprobs[0] is None
+        for prompt_logprobs in result.prompt_logprobs[1:]:
+            # If the prompt token is not included in the top X
+            # logprob, it can return 1 more data
+            assert (len(prompt_logprobs) == num_top_logprobs
+                    or len(prompt_logprobs) == num_top_logprobs + 1)
+
     # Test whether prompt logprobs are consistent with HF
     for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
         # Check prompt logprobs
+        # The first prompt logprob is always None, so we compare it from 1:.
         vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
         for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
             for token_id, logprob in vllm_prompt_logprob_dict.items():
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
                 "The token should be decoded by the time it is returned "
                 " to the user.")

+    # Test if prompt logprobs are correctly set.
+    for vllm_result in vllm_results:
+        token_ids = vllm_result.prompt_token_ids
+        prompt_logprobs = vllm_result.prompt_logprobs
+
+        # The first token doesn't have logprob.
+        assert prompt_logprobs[0] is None
+
+        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
+            assert token_id in logprob_dict
+

 def test_max_logprobs():
     runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
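The test above exercises the new behaviour through the test fixtures; the sketch below shows roughly the same request through the public LLM API instead of the test's VllmRunner fixture. It is a minimal, hypothetical usage example (prompt text and chunk-size values are illustrative, not taken from the commit), assuming a CUDA machine and access to the facebook/opt-125m checkpoint.

    from vllm import LLM, SamplingParams

    # Small token budget so the prompt is prefilled over several scheduler steps.
    llm = LLM(model="facebook/opt-125m",
              enable_chunked_prefill=True,
              max_num_batched_tokens=16,
              max_num_seqs=16,
              max_logprobs=6)

    params = SamplingParams(max_tokens=5,
                            temperature=0.0,
                            logprobs=6,
                            prompt_logprobs=6)

    out = llm.generate(["The quick brown fox jumps over the lazy dog"], params)[0]

    # The first prompt position has no preceding context, so its entry is None;
    # every later position carries a dict of token_id -> logprob info.
    assert out.prompt_logprobs[0] is None
    for logprob_dict in out.prompt_logprobs[1:]:
        assert 6 <= len(logprob_dict) <= 7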
@@ -8,6 +8,7 @@ import torch
 from transformers import GenerationConfig, GenerationMixin

 from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import Counter
@@ -54,6 +55,7 @@ def _do_sample(
     sampler: MockLogitsSampler,
     model_runner: ModelRunner,
     sampling_params: SamplingParams,
+    device: str,
 ):
     seq_group_metadata_list = []
     prompt_lens = []
@@ -68,9 +70,12 @@ def _do_sample(
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=device,
+        pin_memory=model_runner.pin_memory)
     return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)

@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):

     sampling_params = SamplingParams(temperature=0)
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
         n=random.randint(1, 10),
     )
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
         seed=random.randint(0, 10000),
     )
     sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
-                                sampling_params)
+                                sampling_params, device)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
         seed=random.randint(0, 10000),
     )
     first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                      model_runner, sampling_params)
+                                      model_runner, sampling_params, device)

     second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                       model_runner, sampling_params)
+                                       model_runner, sampling_params, device)

     assert first_sampler_output == second_sampler_output

@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
+    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
+               device)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
                              "batch size")

     _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
-    sampling_metadata = model_runner._prepare_sample(
+    sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
         prompt_lens=prompt_lens if prompt_lens else None,
-        subquery_lens=prompt_lens if prompt_lens else None)
+        subquery_lens=prompt_lens if prompt_lens else None,
+        device=device,
+        pin_memory=model_runner.pin_memory)
     # the logits tensor is modified in-place by the sampler
     _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

     def test_sampling(model_runner: ModelRunner):
-        sampling_metadata = model_runner._prepare_sample(
-            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            prompt_lens,
+            subquery_lens=prompt_lens,
+            device=device,
+            pin_memory=model_runner.pin_memory)
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)

@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=device,
+        pin_memory=model_runner.pin_memory)

     sample_probs = None

@@ -6,6 +6,7 @@ import pytest
 import torch

 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,
@@ -2,6 +2,7 @@ import pytest
 import torch

 from vllm.config import ModelConfig, SchedulerConfig
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size

@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
     assert len(input_positions) == sum(prompt_lens)
     torch.testing.assert_close(input_tokens, input_positions)

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     assert len(input_tokens) == sum(prompt_lens)
     assert len(input_positions) == sum(prompt_lens)
     actual = sampling_metadata.selected_token_indices
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
     for prompt_len in prompt_lens:
         expected_selected_token_indices.append(selected_token_start_idx)
         selected_token_start_idx += 1
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
+    sampling_metadata = SamplingMetadata.prepare(
+        seq_group_metadata_list,
+        prompt_lens,
+        subquery_lens=prompt_lens,
+        device=model_runner.device,
+        pin_memory=model_runner.pin_memory)
     actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,
@@ -915,6 +915,20 @@ class Scheduler:
                 self.block_manager.get_common_computed_block_ids(
                     seq_group.get_seqs(status=SequenceStatus.RUNNING)))

+            do_sample = True
+            if seq_group.is_prefill():
+                seqs = seq_group.get_seqs()
+                # Prefill has only 1 sequence.
+                assert len(seqs) == 1
+                # In the next iteration, all prompt tokens are not computed.
+                # It means the prefill is chunked, and we don't need sampling.
+                # NOTE: We use get_len instead of get_prompt_len because when
+                # a sequence is preempted, prefill includes previous generated
+                # output tokens.
+                if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
+                        seqs[0].data.get_len()):
+                    do_sample = False
+
             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
             is_prompt = seq_group.is_prefill()
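The condition added above only enables sampling on the chunk that finishes the prompt: every earlier chunk schedules the group with do_sample=False. Below is a standalone sketch of that arithmetic in plain Python (function and variable names are hypothetical, not the scheduler's actual data structures).

    def needs_sampling(token_chunk_size: int, num_computed_tokens: int,
                       total_len: int) -> bool:
        """A prefill chunk produces a sample only if it reaches the end of the
        sequence (prompt plus any tokens regenerated after preemption)."""
        return token_chunk_size + num_computed_tokens >= total_len

    # A 12-token prompt scheduled in chunks of 4: only the last chunk samples.
    assert not needs_sampling(4, 0, 12)
    assert not needs_sampling(4, 4, 12)
    assert needs_sampling(4, 8, 12)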
@@ -924,6 +938,7 @@ class Scheduler:
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
+                do_sample=do_sample,
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
@@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine):

         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

         # Log stats.
         if self.log_stats:
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceStage)
+                           SequenceGroup, SequenceGroupMetadata)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -476,9 +476,12 @@ class LLMEngine:
         return self.scheduler.has_unfinished_seqs()

     def _process_model_outputs(
-            self, output: List[SamplerOutput],
-            scheduled_seq_groups: List[SequenceGroup],
-            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+        self,
+        output: List[SamplerOutput],
+        scheduled_seq_groups: List[SequenceGroup],
+        ignored_seq_groups: List[SequenceGroup],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.

         Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@ class LLMEngine:
             sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

         # Update the scheduled sequence groups with the model outputs.
-        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
-                                                output_by_sequence_group):
+        for scheduled_seq_group, outputs, seq_group_meta in zip(
+                scheduled_seq_groups, output_by_sequence_group,
+                seq_group_metadata_list):
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)

-            # If all sequences in the sequence group are in DECODE, then we can
-            # process the output tokens. Otherwise, they are (chunked) prefill
-            # samples and should not be processed.
-            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
-            if all(stage == SequenceStage.DECODE for stage in stages):
+            self.output_processor.process_prompt_logprob(seq_group, outputs)
+            if seq_group_meta.do_sample:
                 self.output_processor.process_outputs(seq_group, outputs)

         # Free the finished sequence groups.
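In effect, prompt logprobs are now handed to the output processor on every scheduler step, while token-level output processing happens only for groups whose metadata has do_sample set. A toy sketch of that control flow follows (stand-in dataclasses and names are hypothetical, not the engine's real classes).

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class FakeMeta:
        do_sample: bool

    @dataclass
    class FakeOutput:
        prompt_logprobs: Optional[list]
        samples: list

    def process_step(metas: List[FakeMeta], outputs: List[FakeOutput],
                     processed: List[str]) -> None:
        for meta, out in zip(metas, outputs):
            processed.append("prompt_logprob")   # handled on every step
            if meta.do_sample:
                processed.append("samples")      # only on the finishing chunk

    log: List[str] = []
    # Two chunked-prefill steps followed by the final chunk of the same prompt.
    process_step([FakeMeta(False)], [FakeOutput([{1: -0.5}], [])], log)
    process_step([FakeMeta(False)], [FakeOutput([{2: -0.7}], [])], log)
    process_step([FakeMeta(True)], [FakeOutput([{3: -0.2}], ["tok"])], log)
    assert log == ["prompt_logprob", "prompt_logprob", "prompt_logprob",
                   "samples"]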
@@ -585,7 +586,7 @@ class LLMEngine:

         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

         # Log stats.
         if self.log_stats:
@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler.
         """
         pass
+
+    @abstractmethod
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        """Update prompt logprobs received from outputs to seq_group."""
+        pass
@@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
         self.stop_checker = stop_checker

+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        # TODO(sang): Prompt logprob currently not implemented in multi step
+        # workers.
+        logger.warning(
+            "Prompt logprob is not supported by multi step workers. "
+            "(e.g., speculative decode uses multi step workers).")
+        pass
+
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput]) -> None:
         """Append new tokens in the outputs to sequences in the sequence group.
@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         ), f"{type(self)} does not support multiple outputs per step"
         return self._process_sequence_group_outputs(sequence_group, outputs[0])

-    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
-                                        outputs: SequenceGroupOutput) -> None:
-        # Process prompt logprobs
-        prompt_logprobs = outputs.prompt_logprobs
-        if prompt_logprobs is not None and \
-                seq_group.sampling_params.detokenize and self.detokenizer:
+    def process_prompt_logprob(self, seq_group: SequenceGroup,
+                               outputs: List[SequenceGroupOutput]) -> None:
+        assert len(outputs) == 1, ("Single step should only has 1 output.")
+        output = outputs[0]
+        prompt_logprobs = output.prompt_logprobs
+        if (prompt_logprobs is not None
+                and seq_group.sampling_params.detokenize and self.detokenizer):
             self.detokenizer.decode_prompt_logprobs_inplace(
                 seq_group, prompt_logprobs)
-            seq_group.prompt_logprobs = prompt_logprobs
+            if not seq_group.prompt_logprobs:
+                # The first prompt token's logprob is None because it doesn't
+                # have tokens that are precedent.
+                seq_group.prompt_logprobs = [None]
+            seq_group.prompt_logprobs.extend(prompt_logprobs)
+
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutput) -> None:
         # Process samples
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
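Because a chunked prompt is processed over several steps, process_prompt_logprob now appends each chunk's logprobs instead of overwriting the list, seeding it with None for the first token. A plain-Python sketch of that accumulation (illustrative helper name, not part of vLLM):

    def accumulate_prompt_logprobs(existing, chunk):
        """Mimics the append-only update: None for the first prompt token,
        then one dict per later prompt token, built up chunk by chunk."""
        if not existing:
            existing = [None]
        existing.extend(chunk)
        return existing

    acc = None
    acc = accumulate_prompt_logprobs(acc, [{11: -1.2}, {12: -0.3}])   # chunk 1
    acc = accumulate_prompt_logprobs(acc, [{13: -2.0}])               # chunk 2
    assert acc == [None, {11: -1.2}, {12: -0.3}, {13: -2.0}]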
@@ -1,10 +1,11 @@
 from typing import List

-from vllm.sequence import SamplerOutput
+from vllm.sequence import SamplerOutput, SequenceGroupOutput


-def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
-                                    num_seq_groups: int):
+def create_output_by_sequence_group(
+        sampler_outputs: List[SamplerOutput],
+        num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
@@ -83,30 +83,27 @@ def _apply_logits_processors(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    logits_row_idx = 0
     found_logits_processors = False
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+    logits_processed = 0
+    for seq_group in sampling_metadata.seq_groups:
+        seq_ids = seq_group.seq_ids
+        sampling_params = seq_group.sampling_params
         logits_processors = sampling_params.logits_processors
-        # handle prompt_logprobs by skipping rows in logits added for
-        # the prompt tokens (prompt logprobs are not processed)
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            assert len(seq_ids) == 1
-            logits_row_idx += sampling_metadata.prompt_lens[i] - 1

         if logits_processors:
             found_logits_processors = True
-            for seq_id in seq_ids:
+            for seq_id, logits_row_idx in zip(seq_ids,
+                                              seq_group.sample_indices):
                 logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                token_ids = seq_group.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
                     logits_row = logits_processor(token_ids, logits_row)
                 logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
+
+        logits_processed += len(seq_group.sample_indices) + len(
+            seq_group.prompt_logprob_indices)

     if found_logits_processors:
         # verifies that no rows in logits were missed unexpectedly
-        assert logits_row_idx == logits.shape[0]
+        assert logits_processed == logits.shape[0]
     return logits

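The bookkeeping above relies on each sequence group carrying explicit row indices into the flattened logits tensor: prompt_logprob_indices for prefill rows and sample_indices for rows that will actually be sampled. The small sketch below illustrates how such a partition could look; the dictionary layout and values are hypothetical, not what SamplingMetadata.prepare actually produces.

    # Two groups in one batch: a chunked prefill that wants prompt logprobs
    # (3 query tokens, no sampling yet) and a decode group (1 query token).
    groups = [
        {"prompt_logprob_indices": [0, 1, 2], "sample_indices": []},
        {"prompt_logprob_indices": [], "sample_indices": [3]},
    ]

    num_logits_rows = 4  # rows in the flattened (num_tokens, vocab) logits
    covered = sum(
        len(g["prompt_logprob_indices"]) + len(g["sample_indices"])
        for g in groups)
    # The same invariant the refactored asserts check: every row is accounted for.
    assert covered == num_logits_rows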
@@ -7,11 +7,11 @@ import torch.nn as nn

 from vllm.model_executor.layers.ops.sample import sample as sample_triton
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
-                                                   SamplingTensors)
-from vllm.sampling_params import SamplingParams, SamplingType
+                                                   SamplingTensors,
+                                                   SequenceGroupToSample)
+from vllm.sampling_params import SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
-                           SamplerOutput, SequenceData, SequenceGroupOutput,
-                           SequenceOutput)
+                           SamplerOutput, SequenceGroupOutput, SequenceOutput)


 class Sampler(nn.Module):
@@ -48,11 +48,14 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
+        """
+        Args:
+            logits: (num_tokens, vocab_size).
+            sampling_metadata: Metadata for sampling.
+        """
         assert logits is not None
         _, vocab_size = logits.shape

-        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
-        # have not been generated yet
         logits = _apply_min_tokens_penalty(logits, sampling_metadata)

         # Prepare sampling tensors with pinned memory to avoid blocking.
|
||||||
# Compute the probabilities.
|
# Compute the probabilities.
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
# Compute the log probabilities.
|
# Compute the log probabilities.
|
||||||
# Use log_softmax to ensure numerical stability.
|
|
||||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
|
@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
|
||||||
|
have not been generated yet
|
||||||
|
"""
|
||||||
# list of indices in logits that will be set to -inf
|
# list of indices in logits that will be set to -inf
|
||||||
logits_to_penalize = []
|
logits_to_penalize = []
|
||||||
start_idx = 0
|
logits_applied = 0
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
for seq_group in sampling_metadata.seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids = seq_group.seq_ids
|
||||||
|
sampling_params = seq_group.sampling_params
|
||||||
|
|
||||||
# handle prompt_logprobs by skipping rows in logits added for the prompt
|
sample_indices = seq_group.sample_indices
|
||||||
# tokens (prompt logprobs are not penalized)
|
logits_applied += len(sample_indices) + len(
|
||||||
if (i < sampling_metadata.num_prompts
|
seq_group.prompt_logprob_indices)
|
||||||
and sampling_params.prompt_logprobs is not None):
|
if not seq_group.do_sample:
|
||||||
assert len(seq_ids) == 1
|
continue
|
||||||
start_idx += sampling_metadata.prompt_lens[i] - 1
|
|
||||||
|
|
||||||
|
start_idx = sample_indices[0]
|
||||||
min_tokens = sampling_params.min_tokens
|
min_tokens = sampling_params.min_tokens
|
||||||
if min_tokens > 0:
|
if min_tokens > 0:
|
||||||
seqs_to_penalize = []
|
seqs_to_penalize = []
|
||||||
for i, seq_id in enumerate(seq_ids):
|
for i, seq_id in enumerate(seq_ids):
|
||||||
seq_data = sampling_metadata.seq_data[seq_id]
|
seq_data = seq_group.seq_data[seq_id]
|
||||||
if len(seq_data.output_token_ids) < min_tokens:
|
if len(seq_data.output_token_ids) < min_tokens:
|
||||||
seqs_to_penalize.append(i)
|
seqs_to_penalize.append(i)
|
||||||
|
|
||||||
|
@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
|
||||||
logits_to_penalize.extend(
|
logits_to_penalize.extend(
|
||||||
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
||||||
|
|
||||||
start_idx += len(seq_ids)
|
|
||||||
|
|
||||||
if logits_to_penalize:
|
if logits_to_penalize:
|
||||||
# use zip and * to group indices along each dimension
|
# use zip and * to group indices along each dimension
|
||||||
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
||||||
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
||||||
|
|
||||||
# verifies that no rows in logits were missed unexpectedly
|
# verifies that no rows in logits were missed unexpectedly
|
||||||
assert start_idx == logits.shape[0]
|
assert logits_applied == logits.shape[0]
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
@@ -265,14 +269,30 @@ def _apply_min_p(


 def _greedy_sample(
-    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    selected_seq_groups: List[SequenceGroupToSample],
     samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
+    """Run greedy sampling on a given samples.
+
+    Args:
+        selected_seq_groups: A list of sequence groups batched.
+        samples: (num_selected_samples,) A tensor of samples. The length of
+            samples could be smaller than selected_seq_groups if
+            seq_group.do_sample is False.
+    Returns:
+        Tuple of (next_token_ids, parent_ids). The length of returned list is
+        same as the length of selected_seq_groups. If the corresponding
+        seq_group has do_sample=False, tuple contains ([], [])
+    """
     samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
-        seq_ids, _ = seq_group
+        if not seq_group.do_sample:
+            results.append(([], []))
+            continue
+
+        seq_ids = seq_group.seq_ids
         num_parent_seqs = len(seq_ids)
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
|
||||||
|
|
||||||
|
|
||||||
def _random_sample(
|
def _random_sample(
|
||||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
selected_seq_groups: List[SequenceGroupToSample],
|
||||||
is_prompts: List[bool],
|
|
||||||
random_samples: torch.Tensor,
|
random_samples: torch.Tensor,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
|
"""Run random sampling on a given samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selected_seq_groups: A list of sequence groups batched.
|
||||||
|
random_samples: (num_selected_samples,) A tensor of samples. The
|
||||||
|
length of samples could be smaller than selected_seq_groups if
|
||||||
|
seq_group.do_sample is False.
|
||||||
|
Returns:
|
||||||
|
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||||
|
same as the length of selected_seq_groups. If the corresponding
|
||||||
|
seq_group has do_sample=False, tuple contains ([], [])
|
||||||
|
"""
|
||||||
# Find the maximum best_of value of the prompt phase requests.
|
# Find the maximum best_of value of the prompt phase requests.
|
||||||
random_samples = random_samples.cpu()
|
random_samples = random_samples.cpu()
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
results = []
|
results = []
|
||||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
for seq_group in selected_seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
if not seq_group.do_sample:
|
||||||
|
results.append(([], []))
|
||||||
|
continue
|
||||||
|
|
||||||
|
seq_ids = seq_group.seq_ids
|
||||||
|
sampling_params = seq_group.sampling_params
|
||||||
|
is_prompt = seq_group.is_prompt
|
||||||
num_parent_seqs = len(seq_ids)
|
num_parent_seqs = len(seq_ids)
|
||||||
if is_prompt:
|
if is_prompt:
|
||||||
# Prompt phase.
|
# Prompt phase.
|
||||||
|
@ -311,11 +348,20 @@ def _random_sample(
|
||||||
|
|
||||||
|
|
||||||
def _beam_search_sample(
|
def _beam_search_sample(
|
||||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
selected_seq_groups: List[SequenceGroupToSample],
|
||||||
is_prompts: List[bool],
|
|
||||||
seq_data: Dict[int, SequenceData],
|
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
|
"""Run beam sampling on a given samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selected_seq_groups: A list of sequence groups batched.
|
||||||
|
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
|
||||||
|
on selected sample indices.
|
||||||
|
Returns:
|
||||||
|
Tuple of (next_token_ids, parent_ids). The length of returned list is
|
||||||
|
same as the length of selected_seq_groups. If the corresponding
|
||||||
|
seq_group has do_sample=False, tuple contains ([], [])
|
||||||
|
"""
|
||||||
# We sample 2 * beam_width candidates to make sure that with high
|
# We sample 2 * beam_width candidates to make sure that with high
|
||||||
# probability we can get `beam_width` candidates in addition to
|
# probability we can get `beam_width` candidates in addition to
|
||||||
# the finished sequences for the next iteration. See
|
# the finished sequences for the next iteration. See
|
||||||
|
@ -327,8 +373,13 @@ def _beam_search_sample(
|
||||||
# other sampling methods.
|
# other sampling methods.
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
results = []
|
results = []
|
||||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
for seq_group in selected_seq_groups:
|
||||||
seq_ids, sampling_params = seq_group
|
if not seq_group.do_sample:
|
||||||
|
results.append(([], []))
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_prompt = seq_group.is_prompt
|
||||||
|
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
|
||||||
num_parent_seqs = len(seq_ids)
|
num_parent_seqs = len(seq_ids)
|
||||||
beam_width = sampling_params.best_of
|
beam_width = sampling_params.best_of
|
||||||
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||||
|
@ -343,7 +394,8 @@ def _beam_search_sample(
|
||||||
else:
|
else:
|
||||||
# Generation phase.
|
# Generation phase.
|
||||||
cumulative_logprobs = [
|
cumulative_logprobs = [
|
||||||
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
seq_group.seq_data[seq_id].cumulative_logprob
|
||||||
|
for seq_id in seq_ids
|
||||||
]
|
]
|
||||||
cumulative_logprobs = torch.tensor(
|
cumulative_logprobs = torch.tensor(
|
||||||
cumulative_logprobs,
|
cumulative_logprobs,
|
||||||
|
@@ -371,8 +423,7 @@
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
-    generators: Optional[List[torch.Generator]] = None,
+    seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
|
||||||
q.exponential_()
|
q.exponential_()
|
||||||
else:
|
else:
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
for (seq_ids, _), generator in zip(seq_groups, generators):
|
for seq_group in seq_groups:
|
||||||
|
seq_ids = seq_group.seq_ids
|
||||||
next_sample_idx = sample_idx + len(seq_ids) * num_samples
|
next_sample_idx = sample_idx + len(seq_ids) * num_samples
|
||||||
q[sample_idx:next_sample_idx].exponential_(generator=generator)
|
q[sample_idx:next_sample_idx].exponential_(
|
||||||
|
generator=seq_group.generator)
|
||||||
sample_idx = next_sample_idx
|
sample_idx = next_sample_idx
|
||||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||||
|
|
||||||
|
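Seeded random sampling now pulls the torch.Generator from the sequence group itself rather than from a parallel list. A minimal PyTorch sketch of the exponential-division trick this feeds (tensor shapes and the seed are illustrative; only Tensor.exponential_ with a generator is assumed from the library):

    import torch

    probs = torch.tensor([[0.1, 0.7, 0.2],
                          [0.5, 0.25, 0.25]])

    gen = torch.Generator().manual_seed(1234)
    q = torch.empty_like(probs)
    q.exponential_(generator=gen)          # per-group generator, as in the diff
    sampled = probs.div(q).argmax(dim=1)

    gen2 = torch.Generator().manual_seed(1234)
    q2 = torch.empty_like(probs)
    q2.exponential_(generator=gen2)
    assert torch.equal(sampled, probs.div(q2).argmax(dim=1))  # seeded => repeatable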
@@ -405,7 +458,7 @@ def _sample_with_torch(
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        _, sampling_params = seq_group
+        sampling_params = seq_group.sampling_params
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)

@@ -429,13 +482,11 @@ def _sample_with_torch(
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
-        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
-                                          is_prompts, sample_indices)
-        long_sample_indices = sample_indices.long()

+        seq_group_id = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
+        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
+        long_sample_indices = sample_indices.long()
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
@@ -455,14 +506,13 @@

         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of_in_batch = 1
-            for seq_group, is_prompt in zip(seq_groups, is_prompts):
-                if is_prompt:
-                    _, sampling_params = seq_group
+            for seq_group in seq_groups:
+                if seq_group.is_prompt:
+                    sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
             seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                 "seq_groups": seq_groups,
-                "generators": sampling_metadata.generators,
             }

             multinomial_samples[sampling_type] = _multinomial(
@@ -481,25 +531,22 @@

     # GPU<->CPU sync happens in the loop below.
     # This also converts the sample output to Python objects.

     for sampling_type in SamplingType:
         if sampling_type not in sample_metadata:
             continue
-        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
-            sampling_type]
+        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            sample_results = _random_sample(seq_groups, is_prompts,
+            sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 sampling_metadata.seq_data,
+            sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
-        sample_results_dict.update(zip(seq_group_ids, sample_results))
+        sample_results_dict.update(zip(seq_group_id, sample_results))

     sample_results = [
-        sample_results_dict[i]
+        sample_results_dict.get(i, ([], []))
         for i in range(len(sampling_metadata.seq_groups))
     ]
     return sample_results, sampled_token_ids_tensor
@@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
     categorized_seq_group_ids = {t: [] for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        _, sampling_params = seq_group
+        sampling_params = seq_group.sampling_params
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)

@@ -530,17 +577,16 @@
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
-        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
-                                          is_prompts, sample_indices,
+        seq_group_id = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
+        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
+                                          sample_indices,
                                           sampled_token_indices)
         if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                              SamplingType.RANDOM_SEED):
-            for seq_group, is_prompt in zip(seq_groups, is_prompts):
-                if is_prompt:
-                    _, sampling_params = seq_group
+            for seq_group in seq_groups:
+                if seq_group.is_prompt:
+                    sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
         elif sampling_type == SamplingType.BEAM:
@@ -564,22 +610,21 @@
     for sampling_type in SamplingType:
         if sampling_type not in sample_metadata:
             continue
-        (seq_group_ids, seq_groups, is_prompts, sample_indices,
+        (seq_group_id, seq_groups, sample_indices,
          sampled_token_indices) = sample_metadata[sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(
                 seq_groups, sampled_tokens[sampled_token_indices][:, 0])
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(
-                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
+                seq_groups, sampled_tokens[sampled_token_indices])
         elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 sampling_metadata.seq_data,
+            sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
-        sample_results_dict.update(zip(seq_group_ids, sample_results))
+        sample_results_dict.update(zip(seq_group_id, sample_results))

     sample_results = [
-        sample_results_dict[i]
+        sample_results_dict.get(i, ([], []))
         for i in range(len(sampling_metadata.seq_groups))
     ]
     return sample_results
@@ -590,6 +635,18 @@ def _sample(
     sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool, modify_greedy_probs: bool
 ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+    """
+    Args:
+        probs: (num_query_tokens_in_batch, num_vocab)
+        logprobs: (num_query_tokens_in_batch, num_vocab)
+        sampling_metadata: The metadata for a batch for sampling.
+        sampling_tensors: Tensors that include sampling related metadata.
+
+    Returns:
+        (next_token_ids, parent_seq_ids) for each seq group in a batch.
+            If sampling is skipped, it returns ([], [])
+        sampled_token_ids_tensor: A tensor of sampled token ids.
+    """
     return _sample_with_torch(
         probs,
         logprobs,
@@ -626,56 +683,97 @@ def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
-) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
-        int, float]]]]:
-    # Prepare query indices
-    batched_logprobs_query_seq_indices: List[int] = []
-    batched_logprobs_query_token_indices: List[int] = []
-    # at least get one logprob for each token
+) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
+    """Return sample logprobs and prompt logprobs.
+
+    The logic consists of 3 parts.
+    - Select indices to compute logprob from, ranks of token ids, and
+        the top k token ids from logprobs.
+    - Compute prompt logprobs if required.
+    - Compute sample logprobs if required.
+
+    Args:
+        logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
+            logprob per vocab. Sequence groups' query tokens are batched in a
+            single flattened tensor. For example, assuming there are N
+            seq groups, it is sorted by prefill tokens for seq_group_1 (if
+            prompt logprob is enabled), decode tokens for seq_group_1 (if
+            sampling is required), prefill tokens for seq_group_2, ...
+        sampling_metadata: The sampling metadata.
+        sample_results: (num_seq_groups) The tuple of (next_token_ids,
+            parent_ids) for each sequence group. When beam search is enabled,
+            sample_results can contain different number of seq_ids from
+            sampling_metadata.seq_groups. It is because beam search creates
+            2 * BEAM_WIDTH number of samples (whereas there are only up to
+            BEAM_WIDTH number of seq_ids).
+
+    Returns:
+        A tuple of prompt and sample logprobs per sequence group in a batch.
+    """
+    # The index of query token to calculate logprobs. It includes both
+    # prompt and sample logprob indices.
+    query_indices: List[int] = []
+    # The next token ids to get the logprob value from.
+    next_token_ids: List[int] = []
+    # The largest requested number of logprobs. We find logprobs as many as the
+    # largest num logprobs in this API.
     largest_num_logprobs = 1
-    sample_idx = 0
-    for i, (seq_group, sample_result) in enumerate(
-            zip(sampling_metadata.seq_groups, sample_results)):
-        seq_ids, sampling_params = seq_group
-        next_token_ids, parent_ids = sample_result
-        num_parent_seqs = len(seq_ids)
-        if (i < sampling_metadata.num_prompts
+
+    # Select indices to compute logprob from, ranks of token ids, and the top
+    # k token ids from logprobs.
+    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
+                                          sample_results):
+        sampling_params = seq_group.sampling_params
+
+        # Update indices and tokens for prompt logprobs.
+        if (seq_group.is_prompt
                 and sampling_params.prompt_logprobs is not None):
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.prompt_logprobs)
-            prompt_len = sampling_metadata.prompt_lens[i]
-            prompt_tokens = sampling_metadata.seq_data[
-                seq_ids[0]].prompt_token_ids
-            batched_logprobs_query_seq_indices.extend(
-                sample_idx + j for j in range(prompt_len - 1))
-            batched_logprobs_query_token_indices.extend(
-                token_id for token_id in prompt_tokens[1:])
-            sample_idx += prompt_len - 1
-        batched_logprobs_query_seq_indices.extend(
-            [sample_idx + parent_id for parent_id in parent_ids])
-        batched_logprobs_query_token_indices.extend(next_token_ids)
-        if sampling_params.logprobs is not None:
-            largest_num_logprobs = max(largest_num_logprobs,
-                                       sampling_params.logprobs)
-        sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
-
-    batched_logprobs_query_seq_indices_gpu = torch.tensor(
-        batched_logprobs_query_seq_indices, device=logprobs.device)
-    batched_logprobs_query_token_indices_gpu = torch.tensor(
-        batched_logprobs_query_token_indices, device=logprobs.device)
-
-    # Batched query for logprobs of selected token
-    batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices_gpu,
-        batched_logprobs_query_token_indices_gpu
+            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
+            query_indices.extend(seq_group.prompt_logprob_indices)
+            next_token_ids.extend(next_prompt_tokens)
+
+        # Update indices and next tokens for sample logprob.
+        if seq_group.do_sample:
+            token_ids, parent_seq_ids = sample_result
+            # NOTE: We cannot directly use sample_indices because
+            # sample_indices only contain parent seq_ids of a previous step.
+            # The current step may have different number of seq_ids, and
+            # we can obtain it from `sample_result[1]`.
+            query_idx = seq_group.sample_indices[0]
+            query_indices.extend(
+                [query_idx + parent_id for parent_id in parent_seq_ids])
+            next_token_ids.extend(token_ids)
+
+            if sampling_params.logprobs is not None:
+                largest_num_logprobs = max(largest_num_logprobs,
+                                           sampling_params.logprobs)
+
+    assert len(next_token_ids) == len(query_indices)
+
+    if len(query_indices) == 0:
+        empty_sampled_logprob = []
+        empty_prompt_logprob = None
+        return [empty_prompt_logprob], [empty_sampled_logprob]
+
+    query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
+    next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
+
+    # (num_selected_query_tokens, num_logprobs). Note that query_indices can
+    # contain duplicates if beam search is enabled.
+    selected_logprobs = logprobs[[
+        query_indices_gpu,
+        next_token_ids_gpu,
     ]]
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices_gpu],
-        batched_logprobs_query_token_indices_gpu)
-
-    # Batched query for logprobs of topk tokens
+    ranks = _get_ranks(
+        logprobs[query_indices_gpu],
+        next_token_ids_gpu,
+    )
+    assert selected_logprobs.shape[0] == ranks.shape[0]
+
+    # Logprobs of topk tokens for a batch of sequence groups.
+    # (num_query_tokens_across_batch).
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
                                                  largest_num_logprobs,
@@ -685,79 +783,136 @@ def _get_logprobs(
     else:
         top_logprobs, top_token_ids = None, None

-    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-    batched_ranks_query_result = batched_ranks_query_result.cpu()
-
-    # Gather results
-    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
-    result_sample_logprobs: List[SampleLogprobs] = []
-    sample_idx = 0
-    query_result_idx = 0
-    for i, (seq_group, sample_result) in enumerate(
-            zip(sampling_metadata.seq_groups, sample_results)):
-        seq_ids, sampling_params = seq_group
-        next_token_ids, parent_ids = sample_result
-
-        # Prompt logprobs
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            num_logprobs = sampling_params.prompt_logprobs
-            prompt_tokens = sampling_metadata.seq_data[
-                seq_ids[0]].prompt_token_ids
-            group_prompt_logprobs: PromptLogprobs = [None]
-            for token_id in prompt_tokens[1:]:
-                prompt_logprobs_dict = {
-                    token_id:
-                    (batched_logprobs_query_result[query_result_idx].item(),
-                     batched_ranks_query_result[query_result_idx].item())
-                }
-                if num_logprobs > 0:
-                    prompt_logprobs_dict.update(
-                        zip(
-                            top_token_ids[sample_idx, :num_logprobs].tolist(),
-                            zip(
-                                top_logprobs[
-                                    sample_idx, :num_logprobs].tolist(),
-                                range(1, num_logprobs + 1))))
-                group_prompt_logprobs.append({
-                    token_id: Logprob(*logprob_rank)
-                    for token_id, logprob_rank in prompt_logprobs_dict.items()
-                })
-                sample_idx += 1
-                query_result_idx += 1
-            result_prompt_logprobs.append(group_prompt_logprobs)
-        else:
-            result_prompt_logprobs.append(None)
-
-        # Sample logprobs
-        num_logprobs = sampling_params.logprobs
-        if num_logprobs is None:
-            num_logprobs = 0
-        group_sample_logprobs: SampleLogprobs = []
-        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
-            sample_logprobs_dict = {
-                next_token_id:
-                (batched_logprobs_query_result[query_result_idx].item(),
-                 batched_ranks_query_result[query_result_idx].item())
-            }
-            query_result_idx += 1
-            if num_logprobs >= 0:
-                sample_logprobs_dict.update(
-                    zip(
-                        top_token_ids[sample_idx +
-                                      parent_id, :num_logprobs].tolist(),
-                        zip(
-                            top_logprobs[sample_idx +
-                                         parent_id, :num_logprobs].tolist(),
-                            range(1, num_logprobs + 1))))
-            group_sample_logprobs.append({
-                token_id: Logprob(*logprob_rank)
-                for token_id, logprob_rank in sample_logprobs_dict.items()
-            })
-        result_sample_logprobs.append(group_sample_logprobs)
-        sample_idx += len(seq_ids)
-
-    return result_prompt_logprobs, result_sample_logprobs
+    selected_logprobs = selected_logprobs.cpu()
+    ranks = ranks.cpu()
+
+    # Find prompt/sample logprobs.
+    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
+    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
+    top_logprob_idx = 0
+    selected_logprobs_idx = 0
+
+    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
+                                        sample_results):
+        (prompt_logprobs, top_logprob_idx,
+         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
+             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
+             selected_logprobs_idx, top_logprob_idx)
+        prompt_logprobs_per_seq_group.append(prompt_logprobs)
+
+        (sampled_logprobs, top_logprob_idx,
+         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
+             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
+             top_logprobs, selected_logprobs_idx, top_logprob_idx)
+        sample_logprobs_per_seq_group.append(sampled_logprobs)
+
+    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
+
+
+def _get_prompt_logprob_if_needed(
+    seq_group: SequenceGroupToSample,
+    selected_logprobs: torch.Tensor,
+    ranks: torch.Tensor,
+    top_token_ids: torch.Tensor,
+    top_logprobs: torch.Tensor,
+    selected_logprobs_idx: int,
+    top_logprob_idx: int,
+):
+    """Compute the prompt logprob from a sequence group if needed."""
+    sampling_params = seq_group.sampling_params
+    is_prompt = seq_group.is_prompt
+
+    # Find prompt logprobs
+    prompt_logprobs: Optional[PromptLogprobs] = None
+    if (is_prompt and sampling_params.prompt_logprobs is not None):
+        prompt_logprobs = []
+        num_logprobs = sampling_params.prompt_logprobs
+        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
+        for token_id in next_prompt_tokens:
+            # Calculate the prompt logprob of the real prompt tokens.
+            # Use tuple here for performance (to use to_list()).
+            # {token_id: (logprob, rank_from_vocab)}
+            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
+                token_id: (selected_logprobs[selected_logprobs_idx].item(),
+                           ranks[selected_logprobs_idx].item())
+            }
+
+            # Add top K prompt logprobs along with its rank.
+            if num_logprobs > 0:
+                prompt_logprobs_dict.update(
+                    zip(
+                        top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
+                        zip(
+                            top_logprobs[
+                                top_logprob_idx, :num_logprobs].tolist(),
+                            # This is ranks. Since top_logprob is sorted,
+                            # we can just use a range here.
+                            range(1, num_logprobs + 1))))
+            prompt_logprobs.append({
+                token_id: Logprob(*logprob_and_rank)
+                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
+            })
+            # + 1 to go to the next prompt token.
+            top_logprob_idx += 1
+            selected_logprobs_idx += 1
+    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
+
+
+def _get_sampled_logprob_if_needed(
+    seq_group: SequenceGroupToSample,
+    sample_result: Tuple[List[int], List[int]],
+    selected_logprobs: torch.Tensor,
+    ranks: torch.Tensor,
+    top_token_ids: torch.Tensor,
+    top_logprobs: torch.Tensor,
+    selected_logprobs_idx: int,
+    top_logprob_idx: int,
+):
+    """Compute the sample logprob if needed."""
+    seq_ids = seq_group.seq_ids
+    num_logprobs = seq_group.sampling_params.logprobs
+    if num_logprobs is None:
+        num_logprobs = 0
+    sampled_logprobs: SampleLogprobs = []
+    next_token_ids, parent_seq_ids = sample_result
+
+    if seq_group.do_sample:
+        assert len(next_token_ids) > 0
+        for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
+            # Calculate the sample logprob of the real sampled tokens.
+            # Use tuple here for performance (to use to_list()).
+            # token_id: (logprob, rank_from_vocab)
+            sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
+                next_token_id:
+                (selected_logprobs[selected_logprobs_idx].item(),
+                 ranks[selected_logprobs_idx].item())
+            }
+            # +1 to go to the next sampled token. Note that
+            # selected_logprobs can contain duplicates unlike top_logprobs
+            # when beam search is enabled.
+            selected_logprobs_idx += 1
+
+            # Second, add top K logprobs along with its rank.
+            if num_logprobs >= 0:
+                sampled_logprobs_dict.update(
+                    zip(
+                        top_token_ids[top_logprob_idx +
+                                      parent_id, :num_logprobs].tolist(),
+                        zip(
+                            top_logprobs[top_logprob_idx +
+                                         parent_id, :num_logprobs].tolist(),
+                            # This is rank. Since top_logprob is sorted, we
+                            # can just use a range here.
+                            range(1, num_logprobs + 1))))
+            sampled_logprobs.append({
+                token_id: Logprob(*logprob_and_rank)
+                for token_id, logprob_and_rank in
+                sampled_logprobs_dict.items()
+            })
+        # There are len(seq_ids) number of sampled tokens for the current
+        # sequence group in top_logprobs. Jump to the next seq_group.
+        top_logprob_idx += len(seq_ids)
+    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx


 def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
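The selected-logprob gather and rank lookup above reduce to a single advanced-indexing read plus a comparison count. The following is a rough standalone sketch of that idea with toy tensors; it is not part of the commit, and the hand-rolled rank line stands in for what `_get_ranks` returns.

import torch

# Toy log-probabilities for 4 query tokens over a vocabulary of 5.
logprobs = torch.log_softmax(torch.randn(4, 5), dim=-1)
query_indices = torch.tensor([0, 1, 2, 3])
next_token_ids = torch.tensor([2, 0, 4, 1])

# Logprob of the chosen token at each query position, mirroring
# logprobs[[query_indices_gpu, next_token_ids_gpu]] in the diff.
selected_logprobs = logprobs[query_indices, next_token_ids]

# Rank of the chosen token within the vocab (1 = highest logprob),
# a stand-in for the library's _get_ranks helper.
ranks = (logprobs[query_indices] > selected_logprobs.unsqueeze(-1)).sum(-1) + 1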
@@ -832,7 +987,7 @@ def _build_sampler_output(
                 group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                               sample_results, prompt_logprobs,
                                               sample_logprobs):
-        seq_ids, _ = seq_group
+        seq_ids = seq_group.seq_ids
         next_token_ids, parent_ids = sample_result
         seq_outputs = []
         for parent_id, next_token_id, logprobs in zip(parent_ids,
@@ -854,3 +1009,36 @@ def _build_sampler_output(
         sampled_token_probs=sampled_token_probs,
         sampled_token_ids=sampled_token_ids,
     )
+
+
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
+    """Get a list of next prompt tokens to compute logprob from a
+    given sequence group.
+
+    It is used to compute prompt logprob. Imagine you have logprob for each
+    query token. Query token needs to know the next prompt token id to compute
+    prompt logprob. This is a helper to obtain next prompt token ids.
+
+    This API has to be used only when the caller knows seq_group is in prefill
+    stage.
+
+    Returns:
+        A list of next prompt tokens to compute logprob.
+    """
+    assert seq_group.is_prompt, (
+        "Caller should ensure the sequence group is in a prefill stage.")
+    seq_ids = seq_group.seq_ids
+    subquery_len = seq_group.subquery_len
+    assert subquery_len is not None
+    # prompt has only 1 seq id.
+    assert len(seq_ids) == 1
+    seq_data = seq_group.seq_data[seq_ids[0]]
+    computed_len = seq_data.get_num_computed_tokens()
+    prompt_tokens = seq_data.prompt_token_ids
+    # +1 because we are looking for a next prompt token.
+    next_token_index_start = computed_len + 1
+    next_token_index_end = min(computed_len + subquery_len + 1,
+                               len(prompt_tokens))
+    next_prompt_tokens = prompt_tokens[
+        next_token_index_start:next_token_index_end]
+    return next_prompt_tokens
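The index arithmetic in `_get_next_prompt_tokens` is easier to see with concrete numbers. The sketch below is illustrative only (made-up token ids, not from the commit): each query token's logprob target is the *next* prompt token, so the window is shifted right by one and clamped to the prompt length.

# A prompt of 8 tokens, chunked so that this step computes query tokens
# [computed_len, computed_len + subquery_len) = [3, 6).
prompt_tokens = [101, 102, 103, 104, 105, 106, 107, 108]
computed_len, subquery_len = 3, 3

start = computed_len + 1
end = min(computed_len + subquery_len + 1, len(prompt_tokens))
next_prompt_tokens = prompt_tokens[start:end]
assert next_prompt_tokens == [105, 106, 107]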
@@ -6,57 +6,275 @@ import torch

 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
-from vllm.utils import is_pin_memory_available
+from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
+                        maybe_expand_dim)

 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558


+@dataclass
+class SequenceGroupToSample:
+    # Sequence ids for the sequence group in a previous step.
+    seq_ids: List[int]
+    sampling_params: SamplingParams
+    # seq_id -> sequence data.
+    seq_data: Dict[int, SequenceData]
+    # The length of the prompt of the sequence group. None if it is in a decode
+    # stage.
+    prompt_len: Optional[int]
+    # The length of the query tokens to compute in the current step. None if it
+    # is in a decode stage. The length of subquery_len <= prompt_len.
+    subquery_len: Optional[int]
+    # A random number generator for sampling.
+    generator: Optional[torch.Generator]
+    # True if the sequence group is in prefill stage. False if it is in a
+    # decode stage.
+    is_prompt: bool
+    # Query token indices from logits to compute prompt logprob. Empty if
+    # prompt logprob is not required.
+    prompt_logprob_indices: List[int]
+    # Sample token indices from logits. Empty if sampling is not required.
+    sample_indices: List[int]
+
+    @property
+    def do_sample(self):
+        return len(self.sample_indices) > 0
+
+    def __post_init__(self):
+        if len(self.prompt_logprob_indices) > 0:
+            assert self.sampling_params.prompt_logprobs is not None
+        if self.is_prompt:
+            assert self.prompt_len is not None
+            assert self.subquery_len is not None
+
+
 class SamplingMetadata:
     """Metadata for input sequences. Used in sampler.

+    The usage is as follows:
+    ```
+    hidden_states = execute_model(...)
+    logits = hidden_states[sampling_metadata.selected_token_indices]
+    sample(logits)
+
+    def sample(logits):
+        # Use categorized_sample_indices for sampling....
+    ```
+
     Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
-        prompt_lens: Lengths of prompts.
-        selected_token_indices: Token indices selected for sampling.
+        seq_groups: List of batched sequence groups.
+        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
+            logits from the initial model output hidden states.
         categorized_sample_indices: SamplingType -> token indices to sample.
-        generators: List of torch.Generators to use for seeded sampling
-        perform_sampling: Whether to perform sampling. This option is used to
-            make the sampling only happens in the driver worker, and disable
-            sampling in other worker processes.
+            Each entry is a 2D tensor of (num_indices, 2), where
+            the first item means the sample index within the returned logit
+            (before pruning padding), and the second item means the sample
+            index after pruning using selected_token_indices.
+            For example, if the returned logit is [1, 2, 3], and we select
+            [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
+            the first tuple is [1, 2] (sampled index within original logit),
+            and the second tuple is [0, 1] (sampled index within pruned logit).
+        num_prompts: Number of prompt sequence groups in seq_groups.
     """

     def __init__(
         self,
-        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
-        seq_data: Optional[Dict[int, SequenceData]],
-        prompt_lens: Optional[List[int]],
+        seq_groups: List[SequenceGroupToSample],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
-        generators: Optional[List[torch.Generator]] = None,
-        perform_sampling: bool = True,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        num_prompts: int,
     ) -> None:
         self.seq_groups = seq_groups
-        self.seq_data = seq_data
-        self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
-        self.generators = generators
-        self.perform_sampling = perform_sampling
-
-        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
+        self.num_prompts = num_prompts
+
+    @staticmethod
+    def prepare(
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        prompt_lens: List[int],
+        subquery_lens: Optional[List[int]],
+        device: str,
+        pin_memory: bool,
+    ) -> "SamplingMetadata":
+        (
+            seq_groups,
+            selected_token_indices,
+            categorized_sample_indices,
+            num_prompts,
+        ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
+                                subquery_lens, device)
+        selected_token_indices = async_tensor_h2d(selected_token_indices,
+                                                  dtype=torch.long,
+                                                  target_device=device,
+                                                  pin_memory=pin_memory)
+        categorized_sample_indices = {
+            t: maybe_expand_dim(
+                async_tensor_h2d(seq_ids,
+                                 dtype=torch.int,
+                                 target_device=device,
+                                 pin_memory=pin_memory), 2, 2)
+            for t, seq_ids in categorized_sample_indices.items()
+        }
+
+        sampling_metadata = SamplingMetadata(
+            seq_groups=seq_groups,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
+            num_prompts=num_prompts,
+        )
+        return sampling_metadata

     def __repr__(self) -> str:
         return (
             "SamplingMetadata("
             f"seq_groups={self.seq_groups}, "
-            f"seq_data={self.seq_data}, "
-            f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices}), "
-            f"perform_sampling={self.perform_sampling})")
+            f"categorized_sample_indices={self.categorized_sample_indices}), ")
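The (original index, pruned index) pairing described in the docstring above is just a zip against the position inside selected_token_indices. A toy illustration with made-up numbers, not part of the commit:

# Rows 1 and 2 of the returned logits are kept for sampling, so after
# pruning with selected_token_indices they sit at positions 0 and 1.
selected_token_indices = [1, 2]
categorized_pairs = list(
    zip(selected_token_indices, range(len(selected_token_indices))))
assert categorized_pairs == [(1, 0), (2, 1)]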
+
+
+def _prepare_seq_groups(
+    seq_group_metadata_list: List[SequenceGroupMetadata],
+    prompt_lens: List[int],
+    subquery_lens: Optional[List[int]],
+    device: str,
+) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
+        SamplingType, List[Tuple[int, int]]], int]:
+    """Prepare sequence groups and indices for sampling.
+
+    Args:
+        seq_group_metadata_list: A list of sequence group to batch.
+        prompt_lens: A list of prompt lens per sequence group.
+            Index of prompt len should match with seq_group_metadata_list.
+        subquery_lens: A list of query lengths. Prompt lens include the length
+            of entire prompt tokens, and it could be shorter.
+        device: A device to use for random number generator,
+            `SequenceGroupToSample.generator`.
+
+    Returns:
+        seq_groups: A list of sequence group to sample.
+        selected_token_indices: See the definition from `SamplingMetadata`.
+        categorized_sample_indices: See the definition from `SamplingMetadata`.
+        num_prompts: Total number of prompts from `seq_group_metadata_list`.
+    """
+    # Batched sequence groups for the current model forward step.
+    seq_groups: List[SequenceGroupToSample] = []
+    # A list of token indices to sample/compute logprob. It is used to
+    # prune the outcome logits from the model for the performance.
+    selected_token_indices: List[int] = []
+    # Used for selected_token_indices.
+    model_output_idx = 0
+
+    # Sampling type -> (
+    # indices to sample/prompt logprob within pruned output logits,
+    # indices to sample within pruned logits)
+    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+        t: []
+        for t in SamplingType
+    }
+    # Index of logits to compute logprob. Logits include both prompt logprob
+    # and sample logprob indices.
+    logit_idx = 0
+    # Index to sample from a sample tensor. It is used by triton sample kernel.
+    # See `_sample_with_triton_kernel` for more details.
+    sample_idx = 0
+    # Total number of prompts from given sequence groups.
+    num_prompts = 0
+
+    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+        seq_ids = list(seq_group_metadata.seq_data.keys())
+        sampling_params = seq_group_metadata.sampling_params
+        is_prompt = seq_group_metadata.is_prompt
+        generator: Optional[torch.Generator] = None
+        # If the current seq group is in decode stage, it is None.
+        prompt_len: Optional[int] = None
+        subquery_len: Optional[int] = None
+        prompt_logprob_indices: List[int] = []
+        sample_indices: List[int] = []
+        do_sample = seq_group_metadata.do_sample
+
+        if seq_group_metadata.is_prompt:
+            if sampling_params.seed is not None:
+                seq_group_metadata.state.generator = torch.Generator(
+                    device=device).manual_seed(sampling_params.seed)
+
+            num_prompts += 1
+            num_prefill_sample = len(seq_ids)
+            assert num_prefill_sample == 1
+            assert subquery_lens is not None and prompt_lens is not None
+            subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
+            # If we need sampling, exclude num_prefill_sample tokens from
+            # prompt logprob.
+            prompt_logprob_len = (subquery_len - num_prefill_sample
+                                  if do_sample else subquery_len)
+            sample_len = num_prefill_sample if do_sample else 0
+        else:
+            # Decode
+            prompt_logprob_len = 0
+            sample_len = len(seq_ids) if do_sample else 0
+
+        # Update indices to select from the model output.
+        """
+        This block computes selected_token_indices which is used in the
+        following way.
+
+        hidden_states = model(...)
+        logits = hidden_states[selected_token_indices]
+        """
+
+        if sampling_params.prompt_logprobs:
+            selected_token_indices.extend(
+                range(model_output_idx, model_output_idx + prompt_logprob_len))
+        model_output_idx += prompt_logprob_len
+        if do_sample:
+            selected_token_indices.extend(
+                range(model_output_idx, model_output_idx + sample_len))
+        model_output_idx += sample_len
+
+        # We now find indices for logprob computation and sampling.
+        """
+        This block computes categorized_sample_indices which is used in the
+        following way.
+
+        hidden_states = model(...)
+        logits = hidden_states[selected_token_indices]
+        def sample(logits):
+            # Use categorized_sample_indices for sampling.
+            # prompt_logprob_indices to find prompt logprob indices.
+            # sample_indices to find sample indices.
+        """
+
+        if sampling_params.prompt_logprobs is not None:
+            prompt_logprob_indices.extend(
+                range(logit_idx, logit_idx + prompt_logprob_len))
+            logit_idx += prompt_logprob_len
+        if do_sample:
+            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                list(
+                    zip(range(logit_idx, logit_idx + sample_len),
+                        range(sample_idx, sample_idx + sample_len))))
+            logit_idx += sample_len
+            sample_idx += sample_len
+
+        if sampling_params.seed is not None:
+            generator = seq_group_metadata.state.generator
+
+        seq_groups.append(
+            SequenceGroupToSample(
+                seq_ids=seq_ids,
+                sampling_params=sampling_params,
+                seq_data=seq_group_metadata.seq_data,
+                prompt_len=prompt_len,
+                subquery_len=subquery_len,
+                generator=generator,
+                is_prompt=is_prompt,
+                prompt_logprob_indices=list(prompt_logprob_indices),
+                sample_indices=list(sample_indices)))
+    return (seq_groups, selected_token_indices, categorized_sample_indices,
+            num_prompts)


 @dataclass
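To make the prefill/decode bookkeeping above concrete, here is a small worked example with made-up lengths (not from the commit): one chunked-prefill group whose current chunk reaches the end of the prompt, followed by a decode group with two sequences.

# Prefill group: this chunk has 4 query tokens and do_sample=True, so the
# last token is sampled and the first 3 only need prompt logprobs.
subquery_len, num_prefill_sample, do_sample = 4, 1, True
prompt_logprob_len = (subquery_len - num_prefill_sample
                      if do_sample else subquery_len)      # 3
sample_len = num_prefill_sample if do_sample else 0        # 1

# Walking the pruned logits in order gives:
prompt_logprob_indices = [0, 1, 2]   # prefill tokens of group 1
sample_indices_group_1 = [3]         # last prefill token of group 1
sample_indices_group_2 = [4, 5]      # two decode sequences of group 2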
@@ -112,11 +330,10 @@ class SamplingTensors:
         seeds_to_generate = (extra_seeds_to_generate +
                             get_num_triton_sampler_splits(vocab_size))

-        sample_indices_start_idx = 0
         assert sampling_metadata.seq_groups is not None
-        assert sampling_metadata.seq_data is not None
-        for i, seq_group in enumerate(sampling_metadata.seq_groups):
-            seq_ids, sampling_params = seq_group
+        for seq_group in sampling_metadata.seq_groups:
+            seq_ids = seq_group.seq_ids
+            sampling_params = seq_group.sampling_params
             temperature = sampling_params.temperature
             p = sampling_params.presence_penalty
             f = sampling_params.frequency_penalty

@@ -145,45 +362,46 @@ class SamplingTensors:
                     or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True

-            if (i < sampling_metadata.num_prompts
+            is_prompt = seq_group.is_prompt
+            if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
-                assert sampling_metadata.prompt_lens is not None
-                prompt_len = sampling_metadata.prompt_lens[i]
-                temperatures += [temperature] * (prompt_len - 1)
-                top_ps += [top_p] * (prompt_len - 1)
-                top_ks += [top_k] * (prompt_len - 1)
-                min_ps += [min_p] * (prompt_len - 1)
-                presence_penalties += [0] * (prompt_len - 1)
-                frequency_penalties += [0] * (prompt_len - 1)
-                repetition_penalties += [1] * (prompt_len - 1)
-                prompt_tokens.extend([] for _ in range(prompt_len - 1))
-                output_tokens.extend([] for _ in range(prompt_len - 1))
-            for seq_id in seq_ids:
-                seq_data = sampling_metadata.seq_data[seq_id]
-                prompt_tokens.append(seq_data.prompt_token_ids)
-                output_tokens.append(seq_data.output_token_ids)
-            temperatures += [temperature] * len(seq_ids)
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
-            min_ps += [min_p] * len(seq_ids)
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
-            repetition_penalties += [r] * len(seq_ids)
-
-            is_prompt = i < sampling_metadata.num_prompts
+                subquery_len = seq_group.subquery_len
+                assert subquery_len is not None
+                prefill_len = len(seq_group.prompt_logprob_indices)
+                temperatures += [temperature] * prefill_len
+                top_ps += [top_p] * prefill_len
+                top_ks += [top_k] * prefill_len
+                min_ps += [min_p] * prefill_len
+                presence_penalties += [0] * prefill_len
+                frequency_penalties += [0] * prefill_len
+                repetition_penalties += [1] * prefill_len
+                prompt_tokens.extend([] for _ in range(prefill_len))
+                output_tokens.extend([] for _ in range(prefill_len))
+
+            if seq_group.do_sample:
+                sample_lens = len(seq_group.sample_indices)
+                assert sample_lens == len(seq_ids)
+                for seq_id in seq_ids:
+                    seq_data = seq_group.seq_data[seq_id]
+                    prompt_tokens.append(seq_data.prompt_token_ids)
+                    output_tokens.append(seq_data.output_token_ids)
+                temperatures += [temperature] * len(seq_ids)
+                top_ps += [top_p] * len(seq_ids)
+                top_ks += [top_k] * len(seq_ids)
+                min_ps += [min_p] * len(seq_ids)
+                presence_penalties += [p] * len(seq_ids)
+                frequency_penalties += [f] * len(seq_ids)
+                repetition_penalties += [r] * len(seq_ids)
+
             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
-                assert sampling_metadata.prompt_lens is not None
-                prompt_len = sampling_metadata.prompt_lens[i]
-
-                if sampling_params.prompt_logprobs is not None:
-                    # NOTE: the sampling position is the last token
-                    # in the prompt
-                    sample_indices_start_idx += prompt_len - 1
+                subquery_len = seq_group.subquery_len
+                assert subquery_len is not None
+
             for seq_id in seq_ids:
-                seq_data = sampling_metadata.seq_data[seq_id]
+                seq_data = seq_group.seq_data[seq_id]
                 extra_entropy = extra_entropy or ()
                 seq_seeds = cls._get_sequence_seeds(
                     seed,

@@ -193,8 +411,7 @@ class SamplingTensors:
                     seeds_to_generate=seeds_to_generate,
                     is_greedy=is_greedy)
                 sampling_seeds.append(seq_seeds)
-                sample_indices.append(sample_indices_start_idx)
-                sample_indices_start_idx += 1
+            sample_indices.extend(seq_group.sample_indices)

         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, min_ps, presence_penalties,

@@ -217,12 +434,14 @@ class SamplingTensors:
         # Note that the performance will be very bad without
         # pinned memory.
         pin_memory = is_pin_memory_available()
-        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+        prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
+                             default=0)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))
             for tokens in prompt_tokens
         ]
-        output_max_len = max(len(tokens) for tokens in output_tokens)
+        output_max_len = max([len(tokens) for tokens in output_tokens],
+                             default=0)
         output_padded_tokens = [
             tokens + [vocab_size] * (output_max_len - len(tokens))
             for tokens in output_tokens
@@ -28,7 +28,10 @@ class Logprob:
     decoded_token: Optional[str] = None


+# {token_id -> logprob} per each sequence group. None if the corresponding
+# sequence group doesn't require prompt logprob.
 PromptLogprobs = List[Optional[Dict[int, Logprob]]]
+# {token_id -> logprob} for each sequence group.
 SampleLogprobs = List[Dict[int, Logprob]]
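For orientation, the shapes these aliases describe look roughly like the sketch below. Token ids and values are invented for this note; `Logprob` is constructed positionally as `(logprob, rank)`, the same way the sampler code above calls `Logprob(*logprob_and_rank)`.

from vllm.sequence import Logprob

prompt_logprobs = [
    None,                         # the first prompt token has no predecessor
    {464: Logprob(-0.31, 1)},     # token_id -> Logprob(logprob, rank)
    {1049: Logprob(-2.70, 5)},
]
sample_logprobs = [{29901: Logprob(-1.20, 2)}]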
@@ -215,7 +218,7 @@ class Sequence:
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request

-        self.data = SequenceData(prompt_token_ids)
+        self.data: SequenceData = SequenceData(prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

@@ -559,6 +562,9 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        do_sample: True if sampling is required. Sampling is not required,
+            e.g., when the prefill is chunked and the current iteration only
+            computes query tokens of the prompt.
         token_chunk_size: The number of tokens to be processed (per sequence).
             None if chunking is not required.
         state: Internal state tied to this sequence group.

@@ -573,6 +579,7 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        do_sample: bool = True,
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,

@@ -589,6 +596,7 @@ class SequenceGroupMetadata:
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
         self._token_chunk_size = token_chunk_size
+        self.do_sample = do_sample

         if self._token_chunk_size is None:
             if is_prompt:

@@ -650,6 +658,7 @@ class SequenceGroupOutput:
         prompt_logprobs: Optional[PromptLogprobs],
     ) -> None:
         self.samples = samples
+        # Prompt logprob for each prompt query token.
         self.prompt_logprobs = prompt_logprobs

     def __repr__(self) -> str:
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn

@@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
-from vllm.utils import make_tensor_with_pad, maybe_expand_dim
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad

 logger = init_logger(__name__)

@@ -38,6 +37,8 @@ class CPUModelRunner:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        # Currently, CPU worker doesn't support chunked prefill.
+        assert self.scheduler_config.chunked_prefill_enabled is False
         self.lora_config = lora_config
         self.vision_language_config = vision_language_config
         self.load_config = load_config

@@ -252,99 +253,6 @@ class CPUModelRunner:
             attn_metadata,
         )

-    def _prepare_sample(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        prompt_lens: List[int],
-    ) -> SamplingMetadata:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        selected_token_indices: List[int] = []
-        generators: List[torch.Generator] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices: Dict[SamplingType,
-                                         List[Tuple[int, int]]] = {
-                                             t: []
-                                             for t in SamplingType
-                                         }
-        categorized_sample_indices_start_idx = 0
-        categorized_sampled_token_indices_start_idx = 0
-
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            if seq_group_metadata.is_prompt:
-                assert len(seq_ids) == 1
-                subquery_len = prompt_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    # NOTE: prompt token positions do not need sample, skip
-                    categorized_sample_indices_start_idx += subquery_len - 1
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].append(
-                        (categorized_sample_indices_start_idx,
-                         categorized_sampled_token_indices_start_idx))
-                categorized_sample_indices_start_idx += 1
-                categorized_sampled_token_indices_start_idx += 1
-
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + subquery_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              subquery_len - 1)
-                selected_token_start_idx += subquery_len
-
-                if sampling_params.seed is not None:
-                    seq_group_metadata.state.generator = torch.Generator(
-                        device=self.device).manual_seed(sampling_params.seed)
-            else:
-                num_seqs = len(seq_ids)
-                selected_token_indices.extend(
-                    range(selected_token_start_idx,
-                          selected_token_start_idx + num_seqs))
-                selected_token_start_idx += num_seqs
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].extend(
-                        zip(
-                            range(
-                                categorized_sample_indices_start_idx,
-                                categorized_sample_indices_start_idx +
-                                num_seqs),
-                            range(
-                                categorized_sampled_token_indices_start_idx,
-                                categorized_sampled_token_indices_start_idx +
-                                num_seqs)))
-                categorized_sample_indices_start_idx += num_seqs
-                categorized_sampled_token_indices_start_idx += num_seqs
-
-            if sampling_params.seed is not None:
-                generators.append(seq_group_metadata.state.generator)
-
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long)
-
-        categorized_sample_indices = {
-            t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        sampling_metadata = SamplingMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            generators=generators,
-        )
-        return sampling_metadata
-
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -364,8 +272,15 @@ class CPUModelRunner:
         (input_tokens, input_positions,
          attn_metadata) = self._prepare_decode(seq_group_metadata_list)
         prompt_lens = []
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 prompt_lens)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            prompt_lens,
+            # subquery_lens is not needed if chunked prefill is not
+            # supported. Since CPU worker doesn't support chunked prefill,
+            # just use prompt_lens instead.
+            prompt_lens,
+            self.device,
+            pin_memory=False)
         # Broadcast the metadata.
         metadata_dict = {
             "input_tokens": input_tokens,

@@ -389,7 +304,6 @@ class CPUModelRunner:
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=None,
             generators=None,
-            perform_sampling=False,
         )

         return (input_tokens, input_positions, attn_metadata,

@@ -421,7 +335,7 @@ class CPUModelRunner:
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
-        if not sampling_metadata.perform_sampling:
+        if not self.is_driver_worker:
             return None

         # Sample the next token.
@@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
-                        is_pin_memory_available, make_tensor_with_pad,
-                        maybe_expand_dim)
+from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
+                        make_tensor_with_pad)

 logger = init_logger(__name__)

@@ -547,108 +546,6 @@ class ModelRunner:
             slot_mapping=slot_mapping,
         )

-    def _prepare_sample(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        prompt_lens: List[int],
-        subquery_lens: Optional[List[int]],
-    ) -> SamplingMetadata:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        selected_token_indices: List[int] = []
-        generators: List[torch.Generator] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices: Dict[SamplingType,
-                                         List[Tuple[int, int]]] = {
-                                             t: []
-                                             for t in SamplingType
-                                         }
-        categorized_sample_indices_start_idx = 0
-        categorized_sampled_token_indices_start_idx = 0
-
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            if seq_group_metadata.is_prompt:
-                assert len(seq_ids) == 1
-                assert subquery_lens is not None
-                subquery_len = subquery_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    # NOTE: prompt token positions do not need sample, skip
-                    categorized_sample_indices_start_idx += subquery_len - 1
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].append(
-                        (categorized_sample_indices_start_idx,
-                         categorized_sampled_token_indices_start_idx))
-                categorized_sample_indices_start_idx += 1
-                categorized_sampled_token_indices_start_idx += 1
-
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + subquery_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              subquery_len - 1)
-                selected_token_start_idx += subquery_len
-
-                if sampling_params.seed is not None:
-                    seq_group_metadata.state.generator = torch.Generator(
-                        device=self.device).manual_seed(sampling_params.seed)
-            else:
-                num_seqs = len(seq_ids)
-                selected_token_indices.extend(
-                    range(selected_token_start_idx,
-                          selected_token_start_idx + num_seqs))
-                selected_token_start_idx += num_seqs
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].extend(
-                        list(
-                            zip(
-                                range(
-                                    categorized_sample_indices_start_idx,
-                                    categorized_sample_indices_start_idx +
-                                    num_seqs),
-                                range(
-                                    categorized_sampled_token_indices_start_idx,
-                                    categorized_sampled_token_indices_start_idx
-                                    + num_seqs))))
-                categorized_sample_indices_start_idx += num_seqs
-                categorized_sampled_token_indices_start_idx += num_seqs
-
-            if sampling_params.seed is not None:
-                generators.append(seq_group_metadata.state.generator)
-
-        selected_token_indices = async_tensor_h2d(selected_token_indices,
-                                                  dtype=torch.long,
-                                                  target_device=self.device,
-                                                  pin_memory=self.pin_memory)
-
-        categorized_sample_indices = {
-            t: maybe_expand_dim(
-                async_tensor_h2d(seq_ids,
-                                 dtype=torch.int,
-                                 target_device=self.device,
-                                 pin_memory=self.pin_memory), 2, 2)
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        sampling_metadata = SamplingMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            generators=generators,
-        )
-        return sampling_metadata
-
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -685,9 +582,9 @@ class ModelRunner:
             decode_lora_requests,
             decode_slot_mapping,
         ) = self._prepare_decode(decode_reqs)
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 prompt_lens,
-                                                 subquery_lens)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list, prompt_lens, subquery_lens,
+            self.device, self.pin_memory)

         if not self.scheduler_config.chunked_prefill_enabled:
             assert (len(prefill_reqs) and len(decode_reqs)) == 0

@@ -788,12 +685,9 @@ class ModelRunner:
                 **metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
-                seq_data=None,
-                prompt_lens=None,
                 selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
-                generators=None,
-                perform_sampling=False,
+                num_prompts=0,
             )

         # if it is a mixed batch, decode attn_metadata is broadcasted

@@ -852,7 +746,7 @@ class ModelRunner:
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

         # Only perform sampling in the driver worker.
-        if not sampling_metadata.perform_sampling:
+        if not self.is_driver_worker:
             return None

         # Sample the next token.

@@ -860,6 +754,7 @@ class ModelRunner:
             logits=logits,
             sampling_metadata=sampling_metadata,
         )

         return output

     @torch.inference_mode()
@ -1,4 +1,4 @@
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
from vllm.model_executor.model_loader.neuron import get_neuron_model
|
from vllm.model_executor.model_loader.neuron import get_neuron_model
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||||
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
|
|
||||||
make_tensor_with_pad, maybe_expand_dim)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -141,106 +139,6 @@ class NeuronModelRunner:

         return input_tokens, input_positions, input_block_ids

-    def _prepare_sample(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        prompt_lens: List[int],
-    ) -> SamplingMetadata:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        selected_token_indices: List[int] = []
-        generators: List[torch.Generator] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices: Dict[SamplingType,
-                                         List[Tuple[int, int]]] = {
-                                             t: []
-                                             for t in SamplingType
-                                         }
-        categorized_sample_indices_start_idx = 0
-        categorized_sampled_token_indices_start_idx = 0
-
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            if seq_group_metadata.is_prompt:
-                assert len(seq_ids) == 1
-                assert prompt_lens is not None
-                prompt_len = prompt_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    # NOTE: prompt token positions do not need sample, skip
-                    categorized_sample_indices_start_idx += prompt_len - 1
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].append(
-                        (categorized_sample_indices_start_idx,
-                         categorized_sampled_token_indices_start_idx))
-                categorized_sample_indices_start_idx += 1
-                categorized_sampled_token_indices_start_idx += 1
-
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += prompt_len
-
-                if sampling_params.seed is not None:
-                    seq_group_metadata.state.generator = torch.Generator(
-                        device=self.device).manual_seed(sampling_params.seed)
-            else:
-                num_seqs = len(seq_ids)
-                selected_token_indices.extend(
-                    range(selected_token_start_idx,
-                          selected_token_start_idx + num_seqs))
-                selected_token_start_idx += num_seqs
-
-                categorized_sample_indices[
-                    sampling_params.sampling_type].extend(
-                        zip(
-                            range(
-                                categorized_sample_indices_start_idx,
-                                categorized_sample_indices_start_idx +
-                                num_seqs),
-                            range(
-                                categorized_sampled_token_indices_start_idx,
-                                categorized_sampled_token_indices_start_idx +
-                                num_seqs)))
-                categorized_sample_indices_start_idx += num_seqs
-                categorized_sampled_token_indices_start_idx += num_seqs
-
-                if sampling_params.seed is not None:
-                    generators.append(seq_group_metadata.state.generator)
-
-        selected_token_indices = async_tensor_h2d(selected_token_indices,
-                                                  dtype=torch.long,
-                                                  target_device=self.device,
-                                                  pin_memory=self.pin_memory)
-
-        categorized_sample_indices = {
-            t: maybe_expand_dim(
-                async_tensor_h2d(seq_ids,
-                                 dtype=torch.int,
-                                 target_device=self.device,
-                                 pin_memory=self.pin_memory), 2, 2)
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        sampling_metadata = SamplingMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            generators=generators,
-        )
-        return sampling_metadata
-
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
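
The deleted _prepare_sample above is mostly index bookkeeping, and the same logic now lives in SamplingMetadata.prepare. To make the removed code easier to follow, here is a toy, self-contained reproduction of its core rule: when prompt_logprobs is requested, every prompt position is added to selected_token_indices, but only the last prompt position (plus each decode row) is registered for sampling. The helper below is hypothetical and the numbers are illustrative:

# Toy reproduction of the index bookkeeping in the removed _prepare_sample.
# Assumes one prompt of length 4 with prompt_logprobs enabled, followed by
# one decode sequence; values are illustrative only.
from typing import List, Tuple


def select_indices(prompt_len: int, want_prompt_logprobs: bool,
                   num_decode_seqs: int) -> Tuple[List[int], List[int]]:
    selected_token_indices: List[int] = []
    sample_indices: List[int] = []
    start = 0

    # Prompt: keep every position if prompt logprobs are requested,
    # otherwise only the last one, which feeds the first sampled token.
    if want_prompt_logprobs:
        selected_token_indices.extend(range(start, start + prompt_len - 1))
    selected_token_indices.append(start + prompt_len - 1)
    sample_indices.append(len(selected_token_indices) - 1)
    start += prompt_len

    # Decode sequences: one logit row per sequence, all of them sampled.
    selected_token_indices.extend(range(start, start + num_decode_seqs))
    sample_indices.extend(
        range(len(selected_token_indices) - num_decode_seqs,
              len(selected_token_indices)))
    return selected_token_indices, sample_indices


print(select_indices(prompt_len=4, want_prompt_logprobs=True,
                     num_decode_seqs=1))
# ([0, 1, 2, 3, 4], [3, 4])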
@@ -256,8 +154,15 @@ class NeuronModelRunner:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
             prompt_lens = []
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 prompt_lens)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            prompt_lens,
+            # subquery_lens is not needed if chunked prefill is not
+            # supported. Since neuron worker doesn't support chunked prefill
+            # just use prompt_lens instead.
+            prompt_lens,
+            self.device,
+            self.pin_memory)

         return (input_tokens, input_positions, input_block_ids,
                 sampling_metadata)
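
The call site above now delegates to a single classmethod instead of a per-backend _prepare_sample; since the Neuron worker does not support chunked prefill, prompt_lens is reused for the subquery lengths. A hedged sketch of that design with stand-in types (the prepare signature below mirrors only the positional arguments visible in this hunk and is otherwise an assumption):

# Stand-in types only; the real vllm SamplingMetadata.prepare builds the full
# index tensors. This sketch just shows the delegation pattern and the
# argument order visible in the hunk above.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class PreparedMetadata:
    num_groups: int
    prompt_lens: List[int]
    subquery_lens: Optional[List[int]]
    device: str
    pin_memory: bool

    @classmethod
    def prepare(cls, seq_group_metadata_list: List[Any],
                prompt_lens: List[int], subquery_lens: Optional[List[int]],
                device: str, pin_memory: bool) -> "PreparedMetadata":
        # A real implementation walks seq_group_metadata_list and builds the
        # selected-token and sample indices; here we only record the inputs.
        return cls(len(seq_group_metadata_list), prompt_lens, subquery_lens,
                   device, pin_memory)


# Neuron-style call site: chunked prefill is unsupported, so prompt_lens is
# passed for subquery_lens as well (mirroring the comment in the diff).
print(PreparedMetadata.prepare([], [4, 7], [4, 7], "cpu", False))

Centralizing this in one classmethod keeps the GPU, CPU, and Neuron runners from each carrying a diverging copy of the same sampling-index logic.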