mirror of https://github.com/vllm-project/vllm
[BugFix] Fix cuda graph for MLPSpeculator (#5875)
Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com>
parent b9e84259e9
commit 2110557dab
@@ -52,7 +52,6 @@ if __name__ == "__main__":
         speculative_model="ibm-fms/llama-13b-accelerator",
         # These are currently required for MLPSpeculator decoding
         use_v2_block_manager=True,
-        enforce_eager=True,
     )
 
     print("With speculation")
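With CUDA graphs now working for MLPSpeculator, the example script no longer needs enforce_eager=True. A minimal sketch of the resulting LLM construction, assuming a Llama-13B-class base model (the base model argument is not visible in this hunk and is an assumption):

from vllm import LLM

llm = LLM(
    # Assumption: a Llama-13B-class base model to pair with the
    # ibm-fms/llama-13b-accelerator speculator; not taken from this diff.
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="ibm-fms/llama-13b-accelerator",
    # Still required for MLPSpeculator decoding at this point.
    use_v2_block_manager=True,
    # enforce_eager=True is no longer needed: MLPSpeculator decode now
    # supports CUDA graph capture after this fix.
)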
@@ -1020,10 +1020,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
+            assert model_input.sampling_metadata is not None
+            indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                assert model_input.sampling_metadata is not None
-                hidden_states = hidden_states.index_select(
-                    0, model_input.sampling_metadata.selected_token_indices)
+                hidden_states = hidden_states.index_select(0, indices)
+            elif decode_meta.use_cuda_graph:
+                hidden_states = hidden_states[:len(indices)]
+
             output.hidden_states = hidden_states
 
         return output
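The new elif branch covers the CUDA-graph decode path: the graph is captured for a padded batch size, so the model returns one hidden-state row per padded slot, and only the first len(indices) rows belong to live sequences. A standalone sketch of that trimming step (batch and hidden sizes below are illustrative assumptions, not values from vLLM):

import torch

graph_batch_size = 8   # assumption: batch size the CUDA graph was captured with
hidden_size = 16       # assumption: model hidden size
hidden_states = torch.randn(graph_batch_size, hidden_size)

# Indices of the tokens actually being decoded this step (3 live sequences).
indices = torch.tensor([0, 1, 2])

# Mirrors `hidden_states[:len(indices)]` from the patched ModelRunner:
# drop the padding rows so downstream speculator code sees one row per
# real sequence.
hidden_states = hidden_states[:len(indices)]
assert hidden_states.shape == (len(indices), hidden_size)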