mirror of https://github.com/vllm-project/vllm
Minor code cleaning for SamplingParams (#99)
This commit is contained in:
parent
42f1042e1c
commit
6208d622ca
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Set
|
||||
from typing import Set
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
|
@ -16,54 +16,6 @@ class SamplingParams:
|
|||
max_tokens: int = 16,
|
||||
logprobs: int = 0,
|
||||
) -> None:
|
||||
if n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {n}.")
|
||||
if not -2.0 <= presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
|
||||
if not -2.0 <= frequency_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
|
||||
if temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {temperature}.")
|
||||
if not 0.0 < top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
|
||||
if top_k < -1 or top_k == 0:
|
||||
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
|
||||
f"got {top_k}.")
|
||||
if max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {max_tokens}.")
|
||||
if logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {logprobs}.")
|
||||
|
||||
if use_beam_search:
|
||||
if n == 1:
|
||||
raise ValueError(
|
||||
"n must be greater than 1 when using beam search.")
|
||||
if temperature > 0.0:
|
||||
raise ValueError(
|
||||
"temperature must be 0 when using beam search.")
|
||||
if top_p < 1.0:
|
||||
raise ValueError(
|
||||
"top_p must be 1 when using beam search.")
|
||||
if top_k != -1:
|
||||
raise ValueError(
|
||||
"top_k must be -1 when using beam search.")
|
||||
elif temperature == 0.0:
|
||||
# Zero temperature means greedy sampling.
|
||||
if n > 1:
|
||||
raise ValueError(
|
||||
"n must be 1 when using greedy sampling.")
|
||||
if top_p < 1.0:
|
||||
raise ValueError(
|
||||
"top_p must be 1 when using greedy sampling.")
|
||||
if top_k != -1:
|
||||
raise ValueError(
|
||||
"top_k must be -1 when using greedy sampling.")
|
||||
|
||||
self.n = n
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
|
@ -75,6 +27,55 @@ class SamplingParams:
|
|||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verity_beam_search()
|
||||
elif self.temperature == 0.0:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError("presence_penalty must be in [-2, 2], got "
|
||||
f"{self.presence_penalty}.")
|
||||
if not -2.0 <= self.frequency_penalty <= 2.0:
|
||||
raise ValueError("frequency_penalty must be in [-2, 2], got "
|
||||
f"{self.frequency_penalty}.")
|
||||
if self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {self.temperature}.")
|
||||
if not 0.0 < self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||
if self.top_k < -1 or self.top_k == 0:
|
||||
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
|
||||
f"got {self.top_k}.")
|
||||
if self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
if self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
|
||||
def _verity_beam_search(self) -> None:
|
||||
if self.n == 1:
|
||||
raise ValueError("n must be greater than 1 when using beam search.")
|
||||
if self.temperature > 0.0:
|
||||
raise ValueError("temperature must be 0 when using beam search.")
|
||||
if self.top_p < 1.0:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
if self.n > 1:
|
||||
raise ValueError("n must be 1 when using greedy sampling.")
|
||||
if self.top_p < 1.0:
|
||||
raise ValueError("top_p must be 1 when using greedy sampling.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SamplingParams(n={self.n}, "
|
||||
f"presence_penalty={self.presence_penalty}, "
|
||||
|
|
Loading…
Reference in New Issue