mirror of https://github.com/vllm-project/vllm
Align with huggingface Top K sampling (#753)
This commit is contained in:
parent
805de738f6
commit
d1744376ae
|
@ -71,20 +71,20 @@ class Sampler(nn.Module):
|
|||
# Use in-place division to avoid creating a new tensor.
|
||||
logits.div_(t.unsqueeze(dim=1))
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||
if do_top_p or do_top_k:
|
||||
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
||||
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||
if do_top_p or do_top_k:
|
||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
|
||||
|
@ -235,31 +235,32 @@ def _get_top_p_top_k(
|
|||
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
probs: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
) -> torch.Tensor:
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||
|
||||
# Apply top-p.
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||
probs_sort[top_p_mask] = 0.0
|
||||
logits_sort[top_p_mask] = -float("inf")
|
||||
|
||||
# Apply top-k.
|
||||
# Create a mask for the top-k elements.
|
||||
top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
|
||||
top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
|
||||
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
||||
probs_sort[top_k_mask] = 0.0
|
||||
logits_sort[top_k_mask] = -float("inf")
|
||||
|
||||
# Re-sort the probabilities.
|
||||
probs = torch.gather(probs_sort,
|
||||
logits = torch.gather(logits_sort,
|
||||
dim=-1,
|
||||
index=torch.argsort(probs_idx, dim=-1))
|
||||
return probs
|
||||
index=torch.argsort(logits_idx, dim=-1))
|
||||
return logits
|
||||
|
||||
|
||||
def _get_topk_logprobs(
|
||||
|
@ -301,9 +302,7 @@ def _sample_from_prompt(
|
|||
# Random sampling.
|
||||
# Sample `best_of` tokens for the prompt.
|
||||
num_seqs = sampling_params.best_of
|
||||
next_token_ids = torch.multinomial(prob,
|
||||
num_samples=num_seqs,
|
||||
replacement=True)
|
||||
next_token_ids = torch.multinomial(prob, num_samples=num_seqs)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return next_token_ids
|
||||
|
||||
|
|
Loading…
Reference in New Issue