From f04908cae782e1a2404eb3e4f331718d311d1e0d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 13 Sep 2023 16:38:12 -0700 Subject: [PATCH] [FIX] Minor bug fixes (#1035) * [FIX] Minor bug fixes * Address review comments --- vllm/model_executor/layers/sampler.py | 5 +++-- vllm/sequence.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9dcfa42e2a..013b44060d 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -82,8 +82,9 @@ class Sampler(nn.Module): # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities (before applying top-p and top-k). - logprobs = torch.log(probs) + # Compute the log probabilities. + # Use log_softmax to ensure numerical stability. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. return _sample(probs, logprobs, input_metadata) diff --git a/vllm/sequence.py b/vllm/sequence.py index 795397a373..eac3af2823 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -350,7 +350,7 @@ class SequenceOutputs: def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceOutputs): - return NotImplementedError() + raise NotImplementedError() return (self.parent_seq_id == other.parent_seq_id and self.output_token == other.output_token and self.logprobs == other.logprobs)