Align with huggingface Top K sampling (#753)

commit d1744376ae (parent 805de738f6)
Author: Abraham-Xu, committed by GitHub
Date: 2023-08-16 07:44:33 +08:00
1 changed file with 23 additions and 24 deletions
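
Summary of the change: top-p/top-k truncation now runs on the raw logits before the softmax, masking filtered positions with -inf so that the surviving tokens are renormalized by the softmax, instead of zeroing entries of an already-computed probability vector. This matches how HuggingFace's logits warpers filter tokens. The commit also drops replacement=True from the torch.multinomial call in _sample_from_prompt, so the best_of tokens for a prompt are now drawn without replacement (the PyTorch default).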


@@ -71,20 +71,20 @@ class Sampler(nn.Module):
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))
 
+        # Apply top-p and top-k truncation.
+        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        assert len(top_ps) == len(top_ks) == logits.shape[0]
+        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
+        do_top_k = any(k != self.vocab_size for k in top_ks)
+        if do_top_p or do_top_k:
+            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # Compute the log probabilities (before applying top-p and top-k).
         logprobs = torch.log(probs)
 
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == probs.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
-
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
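
For context, a minimal standalone sketch of why the reordering matters; this is not the vLLM code, just one sequence with an illustrative top_k. Masking logits with -inf before the softmax renormalizes the surviving tokens, the same filtering that HuggingFace's TopKLogitsWarper/TopPLogitsWarper perform, whereas zeroing an already-computed probability vector leaves it unnormalized:

    import torch

    # A minimal sketch, not the vLLM code: one sequence, illustrative top_k.
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    top_k = 2

    # Old style: softmax first, then zero out the tail probabilities.
    probs = torch.softmax(logits, dim=-1)
    vals, idx = probs.sort(descending=True)
    vals[top_k:] = 0.0
    old_probs = torch.zeros_like(probs).scatter(0, idx, vals)

    # New style (HuggingFace-like): mask tail logits with -inf, then softmax.
    lvals, lidx = logits.sort(descending=True)
    lvals[top_k:] = -float("inf")
    new_probs = torch.softmax(
        torch.zeros_like(logits).scatter(0, lidx, lvals), dim=-1)

    print(old_probs.sum())  # < 1.0: kept tokens are not renormalized
    print(new_probs.sum())  # 1.0: kept tokens are renormalized

Since torch.multinomial normalizes its weights internally, both styles sample the kept tokens with the same relative odds; the visible effect of this hunk is that probs and logprobs handed to _sample are now computed from the truncated distribution.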
@@ -235,31 +235,32 @@ def _get_top_p_top_k(
 def _apply_top_p_top_k(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
     top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    logits_sort[top_p_mask] = -float("inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
+    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    logits_sort[top_k_mask] = -float("inf")
 
     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort,
-                         dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
-    return probs
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+    return logits
 
 
 def _get_topk_logprobs(
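
One detail worth calling out in this hunk, with a small illustration that is not from the commit (one already-sorted row, illustrative top_p): subtracting probs_sort from the cumulative sum makes the threshold test exclusive, so the token that first pushes the cumulative mass past top_p is still kept, and at least one token always survives truncation:

    import torch

    # Illustration only: one pre-sorted row, illustrative top_p.
    probs_sort = torch.tensor([0.6, 0.3, 0.08, 0.02])
    top_p = 0.7

    probs_sum = probs_sort.cumsum(dim=-1)   # [0.60, 0.90, 0.98, 1.00]
    exclusive = probs_sum - probs_sort      # [0.00, 0.60, 0.90, 0.98]
    top_p_mask = exclusive > top_p          # [False, False, True, True]
    # Tokens 0 and 1 survive: the smallest prefix whose mass covers top_p.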
@@ -301,9 +302,7 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
+        next_token_ids = torch.multinomial(prob, num_samples=num_seqs)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
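
A note on this last hunk: torch.multinomial defaults to replacement=False, so after this change the best_of tokens drawn for one prompt are pairwise distinct. A quick sketch of the new behavior, with an illustrative distribution:

    import torch

    # Illustrative only: with replacement=False (the default), the two
    # sampled token ids are guaranteed to be distinct.
    prob = torch.tensor([0.5, 0.3, 0.2])
    next_token_ids = torch.multinomial(prob, num_samples=2).tolist()
    assert next_token_ids[0] != next_token_ids[1]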