Aligning `top_p` and `top_k` Sampling (#1885)

* Align top_p and top_k sampling with HuggingFace

* Remove _get_prompt_and_output_tokens

* Rename _apply_top_p_top_k to _apply_top_k_top_p

* Compare top_p/top_k filtering against the HF logits warpers in the sampler test

* Fix test errors
Authored by 陈序 on 2024-01-13 05:51:03 +08:00; committed by GitHub
parent 827cbcd37c
commit 218dc2ccda
2 changed files with 78 additions and 15 deletions
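
What the alignment changes: the old `_apply_top_p_top_k` computed the top-p and top-k masks independently from the same softmax and OR-ed them together, while HuggingFace applies top-k first, renormalizes the surviving probabilities, and only then takes the top-p nucleus. The sketch below (illustrative numbers and helper names, not code from this commit) shows why the two orderings can keep different token sets:

```python
import torch


def nucleus_keep(probs: torch.Tensor, p: float) -> torch.Tensor:
    # HF top-p rule: keep a token iff the probability mass of strictly more
    # likely tokens is still below p (the complement of HF's `<= 1 - p` mask).
    sorted_probs, idx = probs.sort(dim=-1, descending=True)
    keep = sorted_probs.cumsum(dim=-1) - sorted_probs < p
    return idx[keep]


logits = torch.tensor([4.0, 3.0, 2.0, 1.0, 0.0])
p, k = 0.9, 3

# Old order: nucleus taken over the full softmax; top-k was OR-ed in later.
print(nucleus_keep(logits.softmax(dim=-1), p))       # tensor([0, 1, 2])

# New, HF-aligned order: mask to the top-k first, renormalize, then top-p.
kth = logits.topk(k).values[-1]
logits_topk = logits.masked_fill(logits < kth, -float("inf"))
print(nucleus_keep(logits_topk.softmax(dim=-1), p))  # tensor([0, 1])
```

With `p=0.9, k=3`, the full-softmax nucleus keeps three tokens, but the renormalized post-top-k distribution concentrates enough mass in the first two that the third falls outside the nucleus.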

tests/samplers/test_sampler.py

@@ -4,6 +4,7 @@ from unittest.mock import patch

 import pytest
 import torch
+from transformers import GenerationConfig, GenerationMixin

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -233,3 +234,65 @@ def test_sampler_logits_processors(seed: int):
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_top_k_top_p(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    top_k = random.randint(100, 500)
+    top_p = random.random() * 0.1
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                              device="cuda",
+                              dtype=torch.float16)
+    fake_logits = torch.normal(0,
+                               5,
+                               size=(batch_size, vocab_size),
+                               device=input_tensor.device,
+                               dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None)
+
+    generation_model = GenerationMixin()
+    generation_config = GenerationConfig(top_k=top_k,
+                                         top_p=top_p,
+                                         do_sample=True)
+    warpers = generation_model._get_logits_warper(generation_config)
+    assert len(warpers) == 2  # top_p and top_k
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+    sample_probs = None
+
+    def mock_sample(probs, logprobs, sampling_metadata):
+        nonlocal sample_probs
+        sample_probs = probs
+        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+        sampler(embedding=None,
+                hidden_states=input_tensor,
+                sampling_metadata=sampling_metadata)
+    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
+    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
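
The test obtains its reference distribution through `GenerationMixin._get_logits_warper`; the same pipeline can be assembled by hand from the public warper classes. A sketch of that equivalent construction (assuming `transformers` is installed; `dummy_ids` is a placeholder, since these warpers ignore `input_ids`):

```python
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

top_k, top_p = 100, 0.05
scores = torch.randn(4, 32000)
dummy_ids = torch.zeros((4, 1), dtype=torch.long)  # ignored by these warpers

# _get_logits_warper chains TopKLogitsWarper before TopPLogitsWarper, which
# is exactly the ordering the rewritten _apply_top_k_top_p reproduces.
scores = TopKLogitsWarper(top_k=top_k)(dummy_ids, scores)
scores = TopPLogitsWarper(top_p=top_p)(dummy_ids, scores)
hf_probs = torch.softmax(scores, dim=-1, dtype=torch.float)
```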

vllm/model_executor/layers/sampler.py

@@ -76,7 +76,7 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

         if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)

         if do_min_p:
@@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits


-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))

     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))

     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
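
The hunk above is cut off midway through the re-sort step. For reference, here is a self-contained sketch of the complete rewritten filter, with the sort inversion reconstructed via `argsort` rather than the original scatter-based code (an illustration under the assumption `1 <= k <= vocab_size`, not the verbatim vLLM source):

```python
import torch


def apply_top_k_top_p(logits: torch.Tensor, p: torch.Tensor,
                      k: torch.Tensor) -> torch.Tensor:
    # Sort ascending: the k largest logits occupy the last k columns.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Top-k: fetch the k-th largest value per row, mask everything below it.
    kth_index = (logits_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_threshold = logits_sort.gather(1, kth_index)
    logits_sort.masked_fill_(logits_sort < top_k_threshold, -float("inf"))

    # Top-p on the renormalized (post-top-k) distribution: drop tokens whose
    # cumulative mass from the bottom is still <= 1 - p, always keeping the
    # most likely token (last column after the ascending sort).
    probs_sort = logits_sort.softmax(dim=-1)
    top_p_mask = probs_sort.cumsum(dim=-1) <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Undo the sort so the filtered logits return to vocabulary order.
    return logits_sort.gather(dim=-1, index=logits_idx.argsort(dim=-1))
```

For example, `apply_top_k_top_p(torch.randn(2, 8), p=torch.tensor([0.9, 0.9]), k=torch.tensor([4, 4]))` returns logits in which, per row, every token outside the top-4 / top-0.9 nucleus has been set to `-inf`.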