mirror of https://github.com/vllm-project/vllm
Aligning `top_p` and `top_k` Sampling (#1885)
* Align top_p and top_k with huggingface
* remove _get_prompt_and_output_tokens
* rename _apply_top_p_top_k
* compare top_p top_k with hf
* fix test errors
This commit is contained in:
parent 827cbcd37c
commit 218dc2ccda
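For context (an illustration, not part of the commit): when both `top_k` and `top_p` are set, HuggingFace's logits warpers apply top-k filtering first and top-p second, masking rejected tokens with `-inf`; the new test below compares vLLM's sampler output against exactly these warpers. A minimal sketch, assuming only the public `TopKLogitsWarper` and `TopPLogitsWarper` classes from `transformers`:

import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

# Toy batch: one sequence, six-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -3.0]])
dummy_input_ids = torch.zeros((1, 1), dtype=torch.long)  # ignored by these warpers

# HuggingFace order when both are set: top-k first, then top-p.
warped = TopKLogitsWarper(top_k=4)(dummy_input_ids, logits.clone())
warped = TopPLogitsWarper(top_p=0.9)(dummy_input_ids, warped)

probs = torch.softmax(warped, dim=-1)
print(probs)  # tokens removed by either filter end up with probability exactly 0.0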
@@ -4,6 +4,7 @@ from unittest.mock import patch
 
 import pytest
 import torch
+from transformers import GenerationConfig, GenerationMixin
 
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -233,3 +234,65 @@ def test_sampler_logits_processors(seed: int):
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_top_k_top_p(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    top_k = random.randint(100, 500)
+    top_p = random.random() * 0.1
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                              device="cuda",
+                              dtype=torch.float16)
+    fake_logits = torch.normal(0,
+                               5,
+                               size=(batch_size, vocab_size),
+                               device=input_tensor.device,
+                               dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None)
+
+    generation_model = GenerationMixin()
+    generation_config = GenerationConfig(top_k=top_k,
+                                         top_p=top_p,
+                                         do_sample=True)
+    warpers = generation_model._get_logits_warper(generation_config)
+    assert len(warpers) == 2  # top_p and top_k
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+    sample_probs = None
+
+    def mock_sample(probs, logprobs, sampling_metadata):
+        nonlocal sample_probs
+        sample_probs = probs
+        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+        sampler(embedding=None,
+                hidden_states=input_tensor,
+                sampling_metadata=sampling_metadata)
+    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
+    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
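A note on the test above (illustration only): `unittest.mock.patch` swaps the module-level `_sample` helper, so the sampler's forward pass still runs the real temperature/top-k/top-p code while the replacement captures the post-filter probabilities and returns deterministic tokens. A self-contained sketch of the same capture pattern, with hypothetical names (`pipeline`, `_pick`) standing in for `Sampler.forward` and `_sample`:

from unittest.mock import patch

import torch


def _pick(probs: torch.Tensor) -> torch.Tensor:
    # Stand-in for vllm.model_executor.layers.sampler._sample.
    return probs.argmax(dim=-1)


def pipeline(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for Sampler.forward: transform, then hand off to the helper.
    probs = torch.softmax(x, dim=-1)
    return _pick(probs)


captured = None


def fake_pick(probs: torch.Tensor) -> torch.Tensor:
    # Record the intermediate tensor, then return something deterministic.
    global captured
    captured = probs
    return torch.zeros(probs.shape[0], dtype=torch.long)


# Patch the helper where it is looked up, so pipeline() calls the fake.
with patch(f"{__name__}._pick", fake_pick):
    pipeline(torch.randn(2, 5))

assert captured is not None and captured.shape == (2, 5)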
@@ -76,7 +76,7 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
 
         if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
         if do_min_p:
@@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits
 
 
-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
 
     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
 
     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
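To make the rewritten masking concrete, here is a standalone sketch (assumptions: a toy single-row tensor and per-row `p`/`k` tensors; it mirrors the helper above rather than importing it). Sorting ascending puts the k-th largest logit at column `vocab_size - k`; gathering that value gives a per-row threshold, and everything strictly below it is masked. Top-p then masks the low-probability prefix whose cumulative mass stays within `1 - p`, and the last column (the most likely token) is always kept, matching HuggingFace's behavior of never filtering out every token.

import torch


def apply_top_k_top_p_sketch(logits: torch.Tensor, p: torch.Tensor,
                             k: torch.Tensor) -> torch.Tensor:
    # Same idea as the rewritten helper: sort ascending once, mask in place,
    # then scatter the values back to their original vocabulary positions.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Top-k: the k-th largest value sits at column (vocab_size - k); anything
    # strictly below that per-row threshold is filtered out.
    threshold_idx = logits_sort.size(1) - k.to(torch.long)
    threshold = logits_sort.gather(1, threshold_idx.unsqueeze(dim=1))
    logits_sort.masked_fill_(logits_sort < threshold, -float("inf"))

    # Top-p: drop the low-probability prefix whose cumulative mass is still
    # within (1 - p); the last column is always kept.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Undo the sort so the result lines up with the original token order.
    return torch.empty_like(logits_sort).scatter_(dim=-1,
                                                  index=logits_idx,
                                                  src=logits_sort)


logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -3.0]])
out = apply_top_k_top_p_sketch(logits,
                               p=torch.tensor([0.9]),
                               k=torch.tensor([4]))
print(out)  # filtered tokens are -inf; the surviving logits are unchanged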