[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)

This commit is contained in:
Abhinav Goyal 2024-08-22 15:12:24 +05:30 committed by GitHub
parent b3856bef7d
commit a3fce56b88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 854 additions and 83 deletions

View File

@ -0,0 +1,268 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, EAGLE would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
# main model
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@ -70,8 +70,9 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""

View File

@ -6,7 +6,8 @@ import pytest
import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
SamplerOutput, get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k = 5
batch_size = 16
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_request = ExecuteModelRequest(
seq_group_metadata_list,
previous_hidden_states=HiddenStates(
torch.arange(batch_size), seq_group_metadata_list,
torch.arange(batch_size, 2 * batch_size)))
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
_expand_execute_model_request(execute_model_request,
seq_with_bonus_token_in_last_step)
all_seq_ids = torch.tensor(
get_all_seq_ids(
expanded_execute_model_request.seq_group_metadata_list))
ref_expanded_hidden_states = all_seq_ids + batch_size
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
assert (ref_expanded_hidden_states == expanded_execute_model_request.
previous_hidden_states.hidden_states).all().item()

View File

@ -60,6 +60,7 @@ _GENERATION_MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}

View File

@ -0,0 +1,161 @@
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.eagle import EAGLEConfig
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
super().__init__()
self.config = config
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=False)
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
@property
def sampler(self):
return self.model.sampler
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids)
inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
hidden_states = self.model.model(
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.token_map is not None:
_logits = logits
logits = -torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype)
logits[..., self.token_map] = _logits
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights = {}
for name, loaded_weight in weights:
if name == "token_map":
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc."):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
elif name.startswith("lm_head.") or name.startswith("model."):
model_weights[name] = loaded_weight
else:
model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]
weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
weight_loader(self.lm_head.weight, lm_head_weight)
self.model.load_weights(model_weights.items())

View File

@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
class Medusa(nn.Module):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
@ -57,6 +70,12 @@ class Medusa(nn.Module):
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:

View File

@ -1092,6 +1092,10 @@ class SamplerOutput(
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None
@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
the target model to the proposer model.
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""
seq_group_metadata_list: List[SequenceGroupMetadata]
# Scorer hidden states. For prefill step, it is used for hidden states of
# all tokens, whereas for decode step, it use used for last accepted tokens.
hidden_states: torch.Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)
def __post_init__(self):
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
if self.seq_group_metadata_list is not None:
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property
def seq_ids(self) -> List[int]:
return self._seq_ids
def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
def update(self,
hidden_states: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata],
second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states,
torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else
second_last_token_hidden_states
])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids."""
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
# seq_group_metadata_list which might cause problems where a sequence
# may be "paused" then "resumed" later. This should only prune sequences
# which are confirmed to be aborted.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
if self.second_last_token_hidden_states is not None:
self.second_last_token_hidden_states = self\
.second_last_token_hidden_states[index]
self._seq_ids = seq_ids
def expand_with_bonus_tokens(
self, seq_with_bonus_token_in_last_step: set) -> None:
"""Expand hidden states for sequences with bonus tokens. This is in
alignment with `MultiStepWorker._expand_execute_model_request`."""
if self.second_last_token_hidden_states is None \
or not seq_with_bonus_token_in_last_step:
return
index = []
for seq_id in self._seq_ids:
i = self._seq_ids.index(seq_id)
if seq_id in seq_with_bonus_token_in_last_step:
index.append(i + len(self._seq_ids))
index.append(i)
self.hidden_states = torch.cat(
[self.hidden_states, self.second_last_token_hidden_states])[index]
class ExecuteModelRequest(
msgspec.Struct,

View File

@ -203,6 +203,7 @@ class TP1DraftModelRunner(ModelRunner):
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
@ -280,13 +281,30 @@ class TP1DraftModelRunner(ModelRunner):
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
if previous_hidden_states is not None:
hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
hidden_states = None
else:
model_executable = self.model
hidden_states = previous_hidden_states
outputs: List[SamplerOutput] = []
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
# Run model
hidden_states = model_executable(
input_ids=model_input.input_tokens,
@ -296,6 +314,7 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**kwargs,
)
# Compute the logits.

View File

@ -4,8 +4,8 @@ from typing import Dict, List, Set, Tuple
import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput,
SequenceData, SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
@ -157,6 +157,12 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
updated_execute_model_req.seq_group_metadata_list =\
updated_seq_group_metadata_list
if isinstance(updated_execute_model_req.previous_hidden_states,
HiddenStates):
updated_execute_model_req.previous_hidden_states\
.expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
return updated_execute_model_req, indices_of_original_sequence_groups
@staticmethod

View File

@ -147,6 +147,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
if draft_worker_kwargs[
"model_config"].hf_config.model_type == "eagle":
raise NotImplementedError(
"EAGLE does not support TP > 1 yet")
allow_zero_draft_token_step = False
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
@ -355,14 +360,34 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
no_spec = num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list
) == 0 or disable_all_speculation
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a
# communication to inform them.
# no_spec is used to signal non-driver worker about prefill vs decode
# stage. This is needed to ensure that order of execution of proposer
# and scorer is same in both driver and non-driver workers (i.e.,
# scorer -> proposer for prefill and proposer -> scorer in decode). This
# order is needed to support models like EAGLE that take scorer states
# as inputs.
broadcast_dict = dict(
num_lookahead_slots=num_lookahead_slots,
no_spec=no_spec,
disable_all_speculation=disable_all_speculation,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
@ -373,17 +398,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._maybe_disable_speculative_tokens(
disable_all_speculation, execute_model_req.seq_group_metadata_list)
# Speculative decoding is disabled in the following cases:
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list
) == 0 or disable_all_speculation:
if no_spec:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
@ -464,8 +479,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
if not skip_proposer:
self.proposer_worker.execute_model(execute_model_req)
sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1
@ -476,10 +489,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if hidden_states is not None:
if self.previous_hidden_states is None:
self.previous_hidden_states = HiddenStates(
execute_model_req.seq_group_metadata_list, hidden_states)
hidden_states, execute_model_req.seq_group_metadata_list)
else:
self.previous_hidden_states.update(
execute_model_req.seq_group_metadata_list, hidden_states)
hidden_states, execute_model_req.seq_group_metadata_list)
if not skip_proposer:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(
sampler_output.prefill_hidden_states)
self.proposer_worker.execute_model(execute_model_req)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
@ -507,15 +530,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return False
num_lookahead_slots = data["num_lookahead_slots"]
# Even if num_lookahead_slots is zero, we want to run the proposer model
# as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we should
# delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model()
# In case of prefill, scorer_worker has to be run before proposer so
# that the hidden states can be propagated to proposer when needed.
if data["no_spec"]:
self.scorer_worker.execute_model()
if not data["disable_all_speculation"]:
# Even if num_lookahead_slots is zero, we want to run the
# proposer model as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we
# should delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model()
if not data["no_spec"]:
self.scorer_worker.execute_model()
self.scorer_worker.execute_model()
return True
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
@ -546,6 +577,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise RuntimeError("Cannot handle cases where distributed draft "
"workers generate no tokens")
execute_model_req.previous_hidden_states = None
with Timer() as scoring_timer:
proposal_scores = self.scorer.score_proposals(
execute_model_req,
@ -651,10 +684,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
second_last_token_hidden_states = hidden_states[:, -2] # b x d
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
# Store hidden states from target model for subsequent decode step
self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
hidden_states)
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_metadata_list,
second_last_token_hidden_states)
return accepted_token_ids, logprobs
@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
(proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
return new_num_gpu_blocks
def prepare_prefill_hidden_states(
prefill_hidden_states: torch.Tensor) -> HiddenStates:
# For prefill step in proposer, we run the model for N-1 tokens
# because Nth token will be processed in the first decode step. For
# N-1 tokens, the input should be 0:N-1 hidden states which should
# be concatanated with 1:N token (since output of scorer has to be
# the input for proposer). Therefore, we shift the hidden states to
# align n-1th hidden state with nth token.
return HiddenStates(prefill_hidden_states.roll(
shifts=1, dims=0)) if prefill_hidden_states is not None else None

View File

@ -11,10 +11,11 @@ from transformers.models.auto.modeling_auto import (
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig, UltravoxConfig)
EAGLEConfig, InternVLChatConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, RWConfig,
UltravoxConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@ -32,6 +33,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"jais": JAISConfig,
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig,
"ultravox": UltravoxConfig,

View File

@ -1,5 +1,6 @@
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@ -20,6 +21,7 @@ __all__ = [
"InternVLChatConfig",
"JAISConfig",
"MedusaConfig",
"EAGLEConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"UltravoxConfig",

View File

@ -0,0 +1,49 @@
import os
from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig
class EAGLEConfig(PretrainedConfig):
model_type = "eagle"
def __init__(self,
model: Union[PretrainedConfig, dict, None] = None,
truncated_vocab_size: Optional[int] = None,
**kwargs):
model_config = None if model is None else (AutoConfig.for_model(
**model) if isinstance(model, dict) else model)
for k, v in kwargs.items():
if k != "architectures" and k != "model_type" and hasattr(
model_config, k):
setattr(model_config, k, v)
self.model = model_config
if self.model is None:
self.truncated_vocab_size = None
else:
self.truncated_vocab_size = self.model.vocab_size if \
truncated_vocab_size is None else truncated_vocab_size
if "architectures" not in kwargs:
kwargs["architectures"] = ["EAGLEModel"]
super().__init__(**kwargs)
if self.model is not None:
for k, v in self.model.to_dict().items():
if not hasattr(self, k):
setattr(self, k, v)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "EAGLEConfig":
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs)
return cls.from_dict(config_dict, **kwargs)

View File

@ -1,5 +1,6 @@
import dataclasses
import gc
import inspect
import itertools
import time
import warnings
@ -1192,6 +1193,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
previous_hidden_states = None
if "previous_hidden_states" in inspect.signature(
self.model.forward).parameters:
previous_hidden_states = torch.empty(
[max_batch_size,
self.model_config.get_hidden_size()],
dtype=self.model_config.dtype,
device=self.device)
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
@ -1264,6 +1277,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"stream":
graph_capture_context.stream
}
if previous_hidden_states is not None:
capture_inputs[
"previous_hidden_states"] = previous_hidden_states[:
batch_size]
if self.has_seqlen_agnostic:
# Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs.update({
@ -1462,6 +1480,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if model_input.is_prompt:
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif decode_meta.use_cuda_graph:
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
@ -1510,11 +1529,11 @@ class CUDAGraphRunner:
# Note one iteration is not enough for torch.jit.script
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_inputs,
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
torch.cuda.synchronize()
@ -1523,11 +1542,11 @@ class CUDAGraphRunner:
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_inputs,
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
if hidden_or_intermediate_states is not None:
@ -1588,6 +1607,11 @@ class CUDAGraphRunner:
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)
if "previous_hidden_states" in self.input_buffers:
self.input_buffers["previous_hidden_states"].copy_(
kwargs["previous_hidden_states"], non_blocking=True)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
if key != "model_execute_time" and key != "model_forward_time":

View File

@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.sequence import ExecuteModelRequest, SamplerOutput
@ -43,7 +45,7 @@ class MultiStepWorker(Worker):
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput]:
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
"""
Get the driver input and broadcast it to other workers.
"""
@ -85,7 +87,9 @@ class MultiStepWorker(Worker):
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
return model_input, worker_input
# Retuning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def _prepare_last_sampled_token_ids_for_tp_workers(
self,
@ -130,7 +134,8 @@ class MultiStepWorker(Worker):
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
torch.Tensor]]]:
"""
Depending on the current state of the request and multi step worker,
this method may skip the normal _prepare_model_input and
@ -148,8 +153,8 @@ class MultiStepWorker(Worker):
return None
virtual_engine = execute_model_req.virtual_engine
model_input, worker_input = self._get_driver_input_and_broadcast(
execute_model_req)
(model_input, worker_input,
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
assert isinstance(model_input, StatefulModelInput)
if execute_model_req.is_first_multi_step:
# cache the worker input and model input for the next steps
@ -162,7 +167,7 @@ class MultiStepWorker(Worker):
# loop
if broadcast_data is None:
return None
model_input, worker_input = broadcast_data
model_input, worker_input, kwargs = broadcast_data
assert isinstance(model_input, StatefulModelInput)
virtual_engine = worker_input.virtual_engine
if model_input.is_first_multi_step:
@ -186,4 +191,4 @@ class MultiStepWorker(Worker):
assert model_input is not None
assert worker_input is not None
return model_input, worker_input
return model_input, worker_input, kwargs

View File

@ -86,7 +86,7 @@ class Worker(LocalOrDistributedWorkerBase):
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator"]) \
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner

View File

@ -222,7 +222,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
raise NotImplementedError
def _get_worker_input_from_broadcast(
self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
self
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
@ -235,11 +237,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
return model_input, worker_input
kwargs = extract_previous_hidden_states(broadcast_data)
return model_input, worker_input, kwargs
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput]:
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
""" Get the driver input and broadcast it to other workers. """
assert self.is_driver_worker
@ -251,17 +255,21 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
kwargs = extract_previous_hidden_states(execute_model_req)
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)
return model_input, worker_input
return model_input, worker_input, kwargs
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
@ -291,7 +299,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if inputs is None:
return None
model_input, worker_input = inputs
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
self.execute_worker(worker_input)
@ -312,9 +320,14 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"model_execute_time", torch.tensor(0)).item()
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps)
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
@ -360,9 +373,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0:
return []
kwargs = extract_previous_hidden_states(execute_model_req)
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors)
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
class WorkerWrapperBase:
@ -439,3 +458,23 @@ class WorkerWrapperBase:
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def extract_previous_hidden_states(
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
Dict[str, torch.Tensor]:
"""If data contains previous_hidden_states, extract it. This returns a dict
which can be used directly as additional kwargs in any following
execute_model calls. This is used in draft models like EAGLE."""
output = {}
# When called from non-driver worker, data is dict but when called from
# driver worker, data is ExecuteModelRequest.
if isinstance(data, dict):
if "previous_hidden_states" in data:
output["previous_hidden_states"] = data["previous_hidden_states"]
elif data.previous_hidden_states is not None:
output["previous_hidden_states"] = data.previous_hidden_states\
.hidden_states
return output