[BugFix] Fix cuda graph for MLPSpeculator (#5875)

Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com>
This commit is contained in:
Nick Hill 2024-06-26 21:12:10 -07:00 committed by GitHub
parent b9e84259e9
commit 2110557dab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 4 deletions

View File

@@ -52,7 +52,6 @@ if __name__ == "__main__":
speculative_model="ibm-fms/llama-13b-accelerator",
# These are currently required for MLPSpeculator decoding
use_v2_block_manager=True,
enforce_eager=True,
)
print("With speculation")

View File

@@ -1020,10 +1020,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt:
assert model_input.sampling_metadata is not None
hidden_states = hidden_states.index_select(
0, model_input.sampling_metadata.selected_token_indices)
hidden_states = hidden_states.index_select(0, indices)
elif decode_meta.use_cuda_graph:
hidden_states = hidden_states[:len(indices)]
output.hidden_states = hidden_states
return output