[FIX] Minor bug fixes (#1035)

* [FIX] Minor bug fixes

* Address review comments
This commit is contained in:
Zhuohan Li 2023-09-13 16:38:12 -07:00 committed by GitHub
parent ab019eea75
commit f04908cae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -82,8 +82,9 @@ class Sampler(nn.Module):
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)

View File

@ -350,7 +350,7 @@ class SequenceOutputs:
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
return NotImplementedError()
raise NotImplementedError()
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)