mirror of https://github.com/vllm-project/vllm
Aligning `top_p` and `top_k` Sampling (#1885)
* Align top_p and top_k with huggingface * remove _get_prompt_and_output_tokens * rename _apply_top_p_top_k * compare top_p top_k with hf * fix test errors
This commit is contained in:
parent
827cbcd37c
commit
218dc2ccda
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
@ -233,3 +234,65 @@ def test_sampler_logits_processors(seed: int):
|
|||
for _, sequence_output in enumerate(sampler_output):
|
||||
for idx, nth_output in enumerate(sequence_output.samples):
|
||||
assert nth_output.output_token == idx
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_top_k_top_p(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
top_k = random.randint(100, 500)
|
||||
top_p = random.random() * 0.1
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device="cuda",
|
||||
dtype=torch.float16)
|
||||
fake_logits = torch.normal(0,
|
||||
5,
|
||||
size=(batch_size, vocab_size),
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(32000, fake_logits)
|
||||
model_runner = ModelRunner(None, None, None)
|
||||
|
||||
generation_model = GenerationMixin()
|
||||
generation_config = GenerationConfig(top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=True)
|
||||
warpers = generation_model._get_logits_warper(generation_config)
|
||||
assert len(warpers) == 2 # top_p and top_k
|
||||
|
||||
seq_group_metadata_list = []
|
||||
prompt_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
|
||||
sample_probs = None
|
||||
|
||||
def mock_sample(probs, logprobs, sampling_metadata):
|
||||
nonlocal sample_probs
|
||||
sample_probs = probs
|
||||
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
|
||||
|
||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
||||
sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
|
|
@ -76,7 +76,7 @@ class Sampler(nn.Module):
|
|||
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
|
||||
|
||||
if do_top_p_top_k:
|
||||
logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
|
||||
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
||||
sampling_tensors.top_ks)
|
||||
|
||||
if do_min_p:
|
||||
|
@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
|||
return logits
|
||||
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
def _apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
|
||||
top_p_mask = probs_sum > p.unsqueeze_(dim=1)
|
||||
|
||||
# Apply top-k.
|
||||
# Create a mask for the top-k elements.
|
||||
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||
top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
|
||||
|
||||
# Final mask.
|
||||
mask = (top_p_mask | top_k_mask)
|
||||
logits_sort.masked_fill_(mask, -float("inf"))
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# Re-sort the probabilities.
|
||||
src = torch.arange(logits_idx.shape[-1],
|
||||
|
|
Loading…
Reference in New Issue