mirror of https://github.com/vllm-project/vllm
[FIX] Minor bug fixes (#1035)
* [FIX] Minor bug fixes * Address review comments
This commit is contained in:
parent
ab019eea75
commit
f04908cae7
|
@@ -82,8 +82,9 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p and top-k).
-        logprobs = torch.log(probs)
+        # Compute the log probabilities.
+        # Use log_softmax to ensure numerical stability.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
|
|
@@ -350,7 +350,7 @@ class SequenceOutputs:

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
-            return NotImplementedError()
+            raise NotImplementedError()
         return (self.parent_seq_id == other.parent_seq_id
                 and self.output_token == other.output_token
                 and self.logprobs == other.logprobs)
|
Loading…
Reference in New Issue