mirror of https://github.com/vllm-project/vllm
[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
This commit is contained in:
parent
b3856bef7d
commit
a3fce56b88
|
@ -0,0 +1,268 @@
|
|||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, EAGLE would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_greedy_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 4
|
||||
|
||||
# precision
|
||||
PRECISION = "float32"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
|
@ -70,8 +70,9 @@ PRECISION = "float32"
|
|||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
|
@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
|||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
|
@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
|||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
|
@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
|
@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
|
|||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
|
|
|
@ -6,7 +6,8 @@ import pytest
|
|||
import torch
|
||||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
|
||||
SamplerOutput, get_all_seq_ids)
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
|
|||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_expand_execute_model_request_sync_with_expand_hidden_states():
|
||||
"""
|
||||
In this test we verify that the logic for expanding the
|
||||
seq_group_metadata_list remains in sync with the expansion logic of
|
||||
the HiddenStates in _expand_execute_model_request.
|
||||
"""
|
||||
k = 5
|
||||
batch_size = 16
|
||||
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
execute_model_request = ExecuteModelRequest(
|
||||
seq_group_metadata_list,
|
||||
previous_hidden_states=HiddenStates(
|
||||
torch.arange(batch_size), seq_group_metadata_list,
|
||||
torch.arange(batch_size, 2 * batch_size)))
|
||||
|
||||
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
|
||||
_expand_execute_model_request(execute_model_request,
|
||||
seq_with_bonus_token_in_last_step)
|
||||
|
||||
all_seq_ids = torch.tensor(
|
||||
get_all_seq_ids(
|
||||
expanded_execute_model_request.seq_group_metadata_list))
|
||||
ref_expanded_hidden_states = all_seq_ids + batch_size
|
||||
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
|
||||
|
||||
assert (ref_expanded_hidden_states == expanded_execute_model_request.
|
||||
previous_hidden_states.hidden_states).all().item()
|
||||
|
|
|
@ -60,6 +60,7 @@ _GENERATION_MODELS = {
|
|||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
}
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
|
||||
|
||||
class EAGLE(nn.Module):
|
||||
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
||||
Reference implementation: https://github.com/SafeAILab/EAGLE
|
||||
|
||||
Differences from reference implementation:
|
||||
1. In reference, LlamaDecoderLayer implementation doesn't have
|
||||
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
|
||||
but we do as HF implementation also does.
|
||||
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
||||
decoder layer is fixed to be LlamaDecoderLayer.
|
||||
3. We have an optional token_map which reduces draft vocab to most
|
||||
frequently used tokens to give some additional speed-up by reducing
|
||||
sampling overhead. This is disabled unless the checkpoint file has
|
||||
explicit token_map tensor and config has an optional attribute
|
||||
truncated_vocab_size < vocab_size. To use this technique, one has to find
|
||||
the top-k most frequent tokens in target dataset and add that as a tensor
|
||||
in the draft checkpoint (using key token_map). Also, the draft config
|
||||
needs to have truncated_vocab_size (=k) as an attribute."""
|
||||
|
||||
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
architectures = getattr(self.config.model, "architectures", [])
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
|
||||
|
||||
self.model = model_cls(self.config.model, *args, **kwargs)
|
||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||
config.model.hidden_size,
|
||||
bias=False)
|
||||
|
||||
self.orig_vocab_size = config.vocab_size
|
||||
self.truncated_vocab_size = config.truncated_vocab_size
|
||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
self.truncated_vocab_size,
|
||||
logit_scale)
|
||||
|
||||
# Token map is a idx to token mapping to reduce the vocab size for
|
||||
# the draft model. Using smaller vocab size for draft, containing
|
||||
# only most frequent tokens reduces the speculation overhead. This
|
||||
# doesn't affect the acceptance rate much and thus gives more speed
|
||||
# -up. By default, this is disabled and is only used if the EAGLE
|
||||
# checkpoint file has token_map tensor.
|
||||
self.token_map = None
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
return self.model.sampler
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
tok_embeds = self.model.model.embed_tokens(input_ids)
|
||||
inputs_embeds = self.fc(
|
||||
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
|
||||
|
||||
hidden_states = self.model.model(
|
||||
input_ids=None,
|
||||
inputs_embeds=inputs_embeds,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
if self.token_map is not None:
|
||||
_logits = logits
|
||||
logits = -torch.inf * torch.ones(
|
||||
size=(*_logits.shape[:-1], self.orig_vocab_size),
|
||||
device=_logits.device,
|
||||
dtype=_logits.dtype)
|
||||
|
||||
logits[..., self.token_map] = _logits
|
||||
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# due to missing lm_head weights and its config being that of a
|
||||
# Llama model. Here's a compatible version with the same weights:
|
||||
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
|
||||
# Also, here's an example script for converting trained EAGLE
|
||||
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
|
||||
model_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
if name == "token_map":
|
||||
if self.config.truncated_vocab_size < self.config.vocab_size:
|
||||
self.token_map = nn.Parameter(loaded_weight,
|
||||
requires_grad=False)
|
||||
elif name.startswith("fc."):
|
||||
weight_loader = getattr(self.fc.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.fc.weight, loaded_weight)
|
||||
elif name.startswith("model.lm_head.") or name.startswith(
|
||||
"model.model."):
|
||||
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
||||
elif name.startswith("lm_head.") or name.startswith("model."):
|
||||
model_weights[name] = loaded_weight
|
||||
else:
|
||||
model_weights[f"model.{name}"] = loaded_weight
|
||||
|
||||
lm_head_weight = model_weights.pop("lm_head.weight")
|
||||
|
||||
if self.token_map is not None and\
|
||||
lm_head_weight.shape[0] > self.token_map.shape[0]:
|
||||
|
||||
lm_head_weight = lm_head_weight[self.token_map]
|
||||
|
||||
weight_loader = getattr(self.lm_head.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.lm_head.weight, lm_head_weight)
|
||||
|
||||
self.model.load_weights(model_weights.items())
|
|
@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
|
|||
|
||||
|
||||
class Medusa(nn.Module):
|
||||
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
|
||||
Reference implementation: https://github.com/FasterDecoding/Medusa
|
||||
|
||||
Differences from reference implementation:
|
||||
1. Currently this only supports generating proposals from top-1 tokens.
|
||||
2. We have an optional token_map which reduces draft vocab to most
|
||||
frequently used tokens to give some additional speed-up by reducing
|
||||
sampling overhead. This is disabled unless the checkpoint file has
|
||||
explicit token_map tensor and config has an optional attribute
|
||||
truncated_vocab_size < vocab_size. To use this technique, one has to find
|
||||
the top-k most frequent tokens in target dataset and add that as a tensor
|
||||
in the draft checkpoint (using key token_map). Also, the draft config
|
||||
needs to have truncated_vocab_size (=k) as an attribute."""
|
||||
|
||||
def __init__(self, config: MedusaConfig, **_) -> None:
|
||||
super().__init__()
|
||||
|
@ -57,6 +70,12 @@ class Medusa(nn.Module):
|
|||
self.truncated_vocab_size,
|
||||
logit_scale)
|
||||
|
||||
# Token map is a idx to token mapping to reduce the vocab size for
|
||||
# the draft model. Using smaller vocab size for draft, containing
|
||||
# only most frequent tokens reduces the speculation overhead. This
|
||||
# doesn't affect the acceptance rate much and thus gives more speed
|
||||
# -up. By default, this is disabled and is only used if the EAGLE
|
||||
# checkpoint file has token_map tensor.
|
||||
self.token_map = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
|
||||
|
|
|
@ -1092,6 +1092,10 @@ class SamplerOutput(
|
|||
# Optional last hidden states from the model.
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
# Optional prefill hidden states from the model
|
||||
# (used for models like EAGLE).
|
||||
prefill_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
# Time taken in the forward pass for this across all workers
|
||||
model_forward_time: Optional[float] = None
|
||||
|
||||
|
@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
|||
omit_defaults=True): # type: ignore[call-arg]
|
||||
"""Hidden states corresponding to in-progress sequences.
|
||||
Used in speculative decoding to pass hidden states from
|
||||
the target model to the proposer model in the subsequent step.
|
||||
the target model to the proposer model.
|
||||
|
||||
seq_ids are the sequence ids of each entry of the batch
|
||||
dimension of the hidden_states tensor"""
|
||||
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
# Scorer hidden states. For prefill step, it is used for hidden states of
|
||||
# all tokens, whereas for decode step, it use used for last accepted tokens.
|
||||
hidden_states: torch.Tensor
|
||||
# The sequence group metadata list. Only needed for decode step.
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
# Scorer hidden states of the 2nd last token proposed by the proposer (
|
||||
# irrespective of whether it was accepted or not). Only used for cases when
|
||||
# last proposed token is accepted (i.e., in case of bonus tokens). For the
|
||||
# case of no bonus tokens, these are ignored.
|
||||
second_last_token_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
_seq_ids: List[int] = msgspec.field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
|
||||
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
|
||||
if self.seq_group_metadata_list is not None:
|
||||
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
|
||||
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
|
||||
|
||||
@property
|
||||
def seq_ids(self) -> List[int]:
|
||||
return self._seq_ids
|
||||
|
||||
def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
hidden_states: torch.Tensor) -> None:
|
||||
"""Update hidden states from target model invocation."""
|
||||
def update(self,
|
||||
hidden_states: torch.Tensor,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
second_last_token_hidden_states: Optional[torch.Tensor] = None):
|
||||
"""Update hidden states from target model invocation. Only used for
|
||||
decode steps"""
|
||||
assert len(seq_group_metadata_list) == len(hidden_states)
|
||||
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
|
||||
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
|
||||
|
||||
if self.second_last_token_hidden_states is not None:
|
||||
# Adding dummy hidden_states to this to maintain same shape
|
||||
self.second_last_token_hidden_states = torch.cat([
|
||||
self.second_last_token_hidden_states,
|
||||
torch.zeros_like(hidden_states)
|
||||
if second_last_token_hidden_states is None else
|
||||
second_last_token_hidden_states
|
||||
])
|
||||
|
||||
def prune(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
|
||||
"""Prune to provided list of sequence ids."""
|
||||
"""Prune to provided list of sequence ids. Only used for decode steps.
|
||||
"""
|
||||
# Currently this prunes all seq_ids not present in
|
||||
# seq_group_metadata_list which might cause problems where a sequence
|
||||
# may be "paused" then "resumed" later. This should only prune sequences
|
||||
# which are confirmed to be aborted.
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
if seq_ids != self._seq_ids:
|
||||
# Batch contents changed - prune removed sequences.
|
||||
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
|
||||
self.hidden_states = self.hidden_states[index]
|
||||
if self.second_last_token_hidden_states is not None:
|
||||
self.second_last_token_hidden_states = self\
|
||||
.second_last_token_hidden_states[index]
|
||||
self._seq_ids = seq_ids
|
||||
|
||||
def expand_with_bonus_tokens(
|
||||
self, seq_with_bonus_token_in_last_step: set) -> None:
|
||||
"""Expand hidden states for sequences with bonus tokens. This is in
|
||||
alignment with `MultiStepWorker._expand_execute_model_request`."""
|
||||
if self.second_last_token_hidden_states is None \
|
||||
or not seq_with_bonus_token_in_last_step:
|
||||
return
|
||||
|
||||
index = []
|
||||
for seq_id in self._seq_ids:
|
||||
i = self._seq_ids.index(seq_id)
|
||||
if seq_id in seq_with_bonus_token_in_last_step:
|
||||
index.append(i + len(self._seq_ids))
|
||||
index.append(i)
|
||||
|
||||
self.hidden_states = torch.cat(
|
||||
[self.hidden_states, self.second_last_token_hidden_states])[index]
|
||||
|
||||
|
||||
class ExecuteModelRequest(
|
||||
msgspec.Struct,
|
||||
|
|
|
@ -203,6 +203,7 @@ class TP1DraftModelRunner(ModelRunner):
|
|||
self,
|
||||
model_input: ModelInputForGPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
|
@ -280,13 +281,30 @@ class TP1DraftModelRunner(ModelRunner):
|
|||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (self.graph_runners[model_input.virtual_engine]
|
||||
[graph_batch_size])
|
||||
|
||||
if previous_hidden_states is not None:
|
||||
hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
torch.empty([
|
||||
graph_batch_size - previous_hidden_states.shape[0],
|
||||
*previous_hidden_states.shape[1:]
|
||||
],
|
||||
dtype=previous_hidden_states.dtype,
|
||||
device=previous_hidden_states.device)
|
||||
])
|
||||
else:
|
||||
hidden_states = None
|
||||
else:
|
||||
model_executable = self.model
|
||||
hidden_states = previous_hidden_states
|
||||
|
||||
outputs: List[SamplerOutput] = []
|
||||
for step in range(num_steps):
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
|
||||
kwargs = {"previous_hidden_states": hidden_states} \
|
||||
if previous_hidden_states is not None else {}
|
||||
|
||||
# Run model
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
|
@ -296,6 +314,7 @@ class TP1DraftModelRunner(ModelRunner):
|
|||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
|
|
|
@ -4,8 +4,8 @@ from typing import Dict, List, Set, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
|
@ -157,6 +157,12 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
|||
|
||||
updated_execute_model_req.seq_group_metadata_list =\
|
||||
updated_seq_group_metadata_list
|
||||
|
||||
if isinstance(updated_execute_model_req.previous_hidden_states,
|
||||
HiddenStates):
|
||||
updated_execute_model_req.previous_hidden_states\
|
||||
.expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
|
||||
|
||||
return updated_execute_model_req, indices_of_original_sequence_groups
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -147,6 +147,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
draft_worker_kwargs[
|
||||
"model_runner_cls"] = TP1DraftModelRunner
|
||||
else:
|
||||
if draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "eagle":
|
||||
raise NotImplementedError(
|
||||
"EAGLE does not support TP > 1 yet")
|
||||
|
||||
allow_zero_draft_token_step = False
|
||||
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
|
||||
|
||||
|
@ -355,14 +360,34 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
execute_model_req)
|
||||
num_lookahead_slots = execute_model_req.num_lookahead_slots
|
||||
|
||||
# Speculative decoding is disabled in the following cases:
|
||||
# 1. Prefill phase: Speculative decoding is not
|
||||
# used during the prefill phase.
|
||||
# 2. Auto-disable enabled: The running queue size exceeds
|
||||
# the specified threshold.
|
||||
# 3. No request: There are no requests in the batch.
|
||||
# In any of these cases, the proposer and scorer workers
|
||||
# are called normally.
|
||||
no_spec = num_lookahead_slots == 0 or len(
|
||||
execute_model_req.seq_group_metadata_list
|
||||
) == 0 or disable_all_speculation
|
||||
|
||||
# Broadcast how many lookahead slots are scheduled for this step, and
|
||||
# whether all speculation is disabled, to all non-driver workers.
|
||||
|
||||
# This is required as if the number of draft model runs changes
|
||||
# dynamically, the non-driver workers won't know unless we perform a
|
||||
# communication to inform them.
|
||||
|
||||
# no_spec is used to signal non-driver worker about prefill vs decode
|
||||
# stage. This is needed to ensure that order of execution of proposer
|
||||
# and scorer is same in both driver and non-driver workers (i.e.,
|
||||
# scorer -> proposer for prefill and proposer -> scorer in decode). This
|
||||
# order is needed to support models like EAGLE that take scorer states
|
||||
# as inputs.
|
||||
broadcast_dict = dict(
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
no_spec=no_spec,
|
||||
disable_all_speculation=disable_all_speculation,
|
||||
)
|
||||
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
||||
|
@ -373,17 +398,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
self._maybe_disable_speculative_tokens(
|
||||
disable_all_speculation, execute_model_req.seq_group_metadata_list)
|
||||
|
||||
# Speculative decoding is disabled in the following cases:
|
||||
# 1. Prefill phase: Speculative decoding is not
|
||||
# used during the prefill phase.
|
||||
# 2. Auto-disable enabled: The running queue size exceeds
|
||||
# the specified threshold.
|
||||
# 3. No request: There are no requests in the batch.
|
||||
# In any of these cases, the proposer and scorer workers
|
||||
# are called normally.
|
||||
if num_lookahead_slots == 0 or len(
|
||||
execute_model_req.seq_group_metadata_list
|
||||
) == 0 or disable_all_speculation:
|
||||
if no_spec:
|
||||
return self._run_no_spec(execute_model_req,
|
||||
skip_proposer=disable_all_speculation)
|
||||
return self._run_speculative_decoding_step(execute_model_req,
|
||||
|
@ -464,8 +479,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
not called, meaning that the kv-cache in proposer for requests is not
|
||||
updated, so they cannot enable spec decode in the rest decoding.
|
||||
"""
|
||||
if not skip_proposer:
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
|
||||
sampler_output = self.scorer_worker.execute_model(execute_model_req)
|
||||
assert len(sampler_output) == 1
|
||||
|
@ -476,10 +489,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
if hidden_states is not None:
|
||||
if self.previous_hidden_states is None:
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
execute_model_req.seq_group_metadata_list, hidden_states)
|
||||
hidden_states, execute_model_req.seq_group_metadata_list)
|
||||
else:
|
||||
self.previous_hidden_states.update(
|
||||
execute_model_req.seq_group_metadata_list, hidden_states)
|
||||
hidden_states, execute_model_req.seq_group_metadata_list)
|
||||
|
||||
if not skip_proposer:
|
||||
# We prepare the prefill hidden states here so that there no
|
||||
# additional complexity in worker for spec_decode vs non_spec_decode
|
||||
# flow and execute_model doesn't need additional modifications.
|
||||
execute_model_req.previous_hidden_states = \
|
||||
prepare_prefill_hidden_states(
|
||||
sampler_output.prefill_hidden_states)
|
||||
|
||||
self.proposer_worker.execute_model(execute_model_req)
|
||||
|
||||
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
|
||||
execute_model_req=execute_model_req, sampler_output=sampler_output)
|
||||
|
@ -507,15 +530,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
return False
|
||||
num_lookahead_slots = data["num_lookahead_slots"]
|
||||
|
||||
# Even if num_lookahead_slots is zero, we want to run the proposer model
|
||||
# as it may have KV.
|
||||
#
|
||||
# We run the proposer once per lookahead slot. In the future we should
|
||||
# delegate how many times it runs to the proposer.
|
||||
for _ in range(max(num_lookahead_slots, 1)):
|
||||
self.proposer_worker.execute_model()
|
||||
# In case of prefill, scorer_worker has to be run before proposer so
|
||||
# that the hidden states can be propagated to proposer when needed.
|
||||
if data["no_spec"]:
|
||||
self.scorer_worker.execute_model()
|
||||
|
||||
if not data["disable_all_speculation"]:
|
||||
# Even if num_lookahead_slots is zero, we want to run the
|
||||
# proposer model as it may have KV.
|
||||
#
|
||||
# We run the proposer once per lookahead slot. In the future we
|
||||
# should delegate how many times it runs to the proposer.
|
||||
for _ in range(max(num_lookahead_slots, 1)):
|
||||
self.proposer_worker.execute_model()
|
||||
|
||||
if not data["no_spec"]:
|
||||
self.scorer_worker.execute_model()
|
||||
|
||||
self.scorer_worker.execute_model()
|
||||
return True
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||
|
@ -546,6 +577,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
raise RuntimeError("Cannot handle cases where distributed draft "
|
||||
"workers generate no tokens")
|
||||
|
||||
execute_model_req.previous_hidden_states = None
|
||||
|
||||
with Timer() as scoring_timer:
|
||||
proposal_scores = self.scorer.score_proposals(
|
||||
execute_model_req,
|
||||
|
@ -651,10 +684,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
|
||||
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
|
||||
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
|
||||
second_last_token_hidden_states = hidden_states[:, -2] # b x d
|
||||
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
|
||||
# Store hidden states from target model for subsequent decode step
|
||||
self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
|
||||
hidden_states)
|
||||
self.previous_hidden_states = HiddenStates(
|
||||
hidden_states, seq_group_metadata_list,
|
||||
second_last_token_hidden_states)
|
||||
|
||||
return accepted_token_ids, logprobs
|
||||
|
||||
|
@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
|
|||
(proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
|
||||
|
||||
return new_num_gpu_blocks
|
||||
|
||||
|
||||
def prepare_prefill_hidden_states(
|
||||
prefill_hidden_states: torch.Tensor) -> HiddenStates:
|
||||
# For prefill step in proposer, we run the model for N-1 tokens
|
||||
# because Nth token will be processed in the first decode step. For
|
||||
# N-1 tokens, the input should be 0:N-1 hidden states which should
|
||||
# be concatanated with 1:N token (since output of scorer has to be
|
||||
# the input for proposer). Therefore, we shift the hidden states to
|
||||
# align n-1th hidden state with nth token.
|
||||
return HiddenStates(prefill_hidden_states.roll(
|
||||
shifts=1, dims=0)) if prefill_hidden_states is not None else None
|
||||
|
|
|
@ -11,10 +11,11 @@ from transformers.models.auto.modeling_auto import (
|
|||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
||||
InternVLChatConfig, JAISConfig,
|
||||
MedusaConfig, MLPSpeculatorConfig,
|
||||
MPTConfig, NemotronConfig,
|
||||
RWConfig, UltravoxConfig)
|
||||
EAGLEConfig, InternVLChatConfig,
|
||||
JAISConfig, MedusaConfig,
|
||||
MLPSpeculatorConfig, MPTConfig,
|
||||
NemotronConfig, RWConfig,
|
||||
UltravoxConfig)
|
||||
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
|
@ -32,6 +33,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||
"jais": JAISConfig,
|
||||
"mlp_speculator": MLPSpeculatorConfig,
|
||||
"medusa": MedusaConfig,
|
||||
"eagle": EAGLEConfig,
|
||||
"internvl_chat": InternVLChatConfig,
|
||||
"nemotron": NemotronConfig,
|
||||
"ultravox": UltravoxConfig,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
|
@ -20,6 +21,7 @@ __all__ = [
|
|||
"InternVLChatConfig",
|
||||
"JAISConfig",
|
||||
"MedusaConfig",
|
||||
"EAGLEConfig",
|
||||
"MLPSpeculatorConfig",
|
||||
"NemotronConfig",
|
||||
"UltravoxConfig",
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
|
||||
class EAGLEConfig(PretrainedConfig):
|
||||
model_type = "eagle"
|
||||
|
||||
def __init__(self,
|
||||
model: Union[PretrainedConfig, dict, None] = None,
|
||||
truncated_vocab_size: Optional[int] = None,
|
||||
**kwargs):
|
||||
|
||||
model_config = None if model is None else (AutoConfig.for_model(
|
||||
**model) if isinstance(model, dict) else model)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if k != "architectures" and k != "model_type" and hasattr(
|
||||
model_config, k):
|
||||
setattr(model_config, k, v)
|
||||
|
||||
self.model = model_config
|
||||
|
||||
if self.model is None:
|
||||
self.truncated_vocab_size = None
|
||||
else:
|
||||
self.truncated_vocab_size = self.model.vocab_size if \
|
||||
truncated_vocab_size is None else truncated_vocab_size
|
||||
|
||||
if "architectures" not in kwargs:
|
||||
kwargs["architectures"] = ["EAGLEModel"]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.model is not None:
|
||||
for k, v in self.model.to_dict().items():
|
||||
if not hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
**kwargs,
|
||||
) -> "EAGLEConfig":
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, **kwargs)
|
||||
return cls.from_dict(config_dict, **kwargs)
|
|
@ -1,5 +1,6 @@
|
|||
import dataclasses
|
||||
import gc
|
||||
import inspect
|
||||
import itertools
|
||||
import time
|
||||
import warnings
|
||||
|
@ -1192,6 +1193,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
||||
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
||||
|
||||
# Prepare dummy previous_hidden_states only if needed by the model.
|
||||
# This is used by draft models such as EAGLE.
|
||||
previous_hidden_states = None
|
||||
if "previous_hidden_states" in inspect.signature(
|
||||
self.model.forward).parameters:
|
||||
previous_hidden_states = torch.empty(
|
||||
[max_batch_size,
|
||||
self.model_config.get_hidden_size()],
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
|
||||
intermediate_inputs = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_inputs = self.model.make_empty_intermediate_tensors(
|
||||
|
@ -1264,6 +1277,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
|||
"stream":
|
||||
graph_capture_context.stream
|
||||
}
|
||||
if previous_hidden_states is not None:
|
||||
capture_inputs[
|
||||
"previous_hidden_states"] = previous_hidden_states[:
|
||||
batch_size]
|
||||
|
||||
if self.has_seqlen_agnostic:
|
||||
# Only used by Mamba-based models CUDA graph atm (Jamba)
|
||||
capture_inputs.update({
|
||||
|
@ -1462,6 +1480,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||
if model_input.is_prompt:
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif decode_meta.use_cuda_graph:
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
|
@ -1510,11 +1529,11 @@ class CUDAGraphRunner:
|
|||
# Note one iteration is not enough for torch.jit.script
|
||||
for _ in range(_NUM_WARMUP_ITERS):
|
||||
self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_inputs,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
@ -1523,11 +1542,11 @@ class CUDAGraphRunner:
|
|||
self._graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
||||
output_hidden_or_intermediate_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_inputs,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
if hidden_or_intermediate_states is not None:
|
||||
|
@ -1588,6 +1607,11 @@ class CUDAGraphRunner:
|
|||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||
**kwargs)
|
||||
|
||||
if "previous_hidden_states" in self.input_buffers:
|
||||
self.input_buffers["previous_hidden_states"].copy_(
|
||||
kwargs["previous_hidden_states"], non_blocking=True)
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
for key in intermediate_tensors.tensors:
|
||||
if key != "model_execute_time" and key != "model_forward_time":
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
|
@ -43,7 +45,7 @@ class MultiStepWorker(Worker):
|
|||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput]:
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
|
@ -85,7 +87,9 @@ class MultiStepWorker(Worker):
|
|||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
return model_input, worker_input
|
||||
# Retuning empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
||||
self,
|
||||
|
@ -130,7 +134,8 @@ class MultiStepWorker(Worker):
|
|||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
"""
|
||||
Depending on the current state of the request and multi step worker,
|
||||
this method may skip the normal _prepare_model_input and
|
||||
|
@ -148,8 +153,8 @@ class MultiStepWorker(Worker):
|
|||
return None
|
||||
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
model_input, worker_input = self._get_driver_input_and_broadcast(
|
||||
execute_model_req)
|
||||
(model_input, worker_input,
|
||||
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
if execute_model_req.is_first_multi_step:
|
||||
# cache the worker input and model input for the next steps
|
||||
|
@ -162,7 +167,7 @@ class MultiStepWorker(Worker):
|
|||
# loop
|
||||
if broadcast_data is None:
|
||||
return None
|
||||
model_input, worker_input = broadcast_data
|
||||
model_input, worker_input, kwargs = broadcast_data
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
if model_input.is_first_multi_step:
|
||||
|
@ -186,4 +191,4 @@ class MultiStepWorker(Worker):
|
|||
|
||||
assert model_input is not None
|
||||
assert worker_input is not None
|
||||
return model_input, worker_input
|
||||
return model_input, worker_input, kwargs
|
||||
|
|
|
@ -86,7 +86,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
|||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator"]) \
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
|
|
|
@ -222,7 +222,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
raise NotImplementedError
|
||||
|
||||
def _get_worker_input_from_broadcast(
|
||||
self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
|
||||
self
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
|
||||
str, torch.Tensor]]]:
|
||||
""" Get the worker input from the broadcasted tensor dict. """
|
||||
assert self.do_metadata_broadcast
|
||||
assert not self.is_driver_worker
|
||||
|
@ -235,11 +237,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
|
||||
return model_input, worker_input
|
||||
kwargs = extract_previous_hidden_states(broadcast_data)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput]:
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
""" Get the driver input and broadcast it to other workers. """
|
||||
assert self.is_driver_worker
|
||||
|
||||
|
@ -251,17 +255,21 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
kwargs = extract_previous_hidden_states(execute_model_req)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_data.update(kwargs)
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
return model_input, worker_input
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
|
||||
str, torch.Tensor]]]:
|
||||
"""
|
||||
Prepare the inputs to ModelRunner and workers.
|
||||
"""
|
||||
|
@ -291,7 +299,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
if inputs is None:
|
||||
return None
|
||||
|
||||
model_input, worker_input = inputs
|
||||
model_input, worker_input, kwargs = inputs
|
||||
num_steps = worker_input.num_steps
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
@ -312,9 +320,14 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
"model_execute_time", torch.tensor(0)).item()
|
||||
|
||||
output = self.model_runner.execute_model(
|
||||
model_input, self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None, intermediate_tensors,
|
||||
num_steps)
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
num_steps=num_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_execute_time = time.perf_counter() - start_time
|
||||
if not get_pp_group().is_last_rank:
|
||||
# output is IntermediateTensors
|
||||
|
@ -360,9 +373,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||
if worker_input.num_seq_groups == 0:
|
||||
return []
|
||||
|
||||
kwargs = extract_previous_hidden_states(execute_model_req)
|
||||
|
||||
return self.model_runner.execute_model(
|
||||
model_input, self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None, intermediate_tensors)
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class WorkerWrapperBase:
|
||||
|
@ -439,3 +458,23 @@ class WorkerWrapperBase:
|
|||
"This might cause deadlock in distributed execution.")
|
||||
logger.exception(msg)
|
||||
raise e
|
||||
|
||||
|
||||
def extract_previous_hidden_states(
|
||||
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
|
||||
Dict[str, torch.Tensor]:
|
||||
"""If data contains previous_hidden_states, extract it. This returns a dict
|
||||
which can be used directly as additional kwargs in any following
|
||||
execute_model calls. This is used in draft models like EAGLE."""
|
||||
output = {}
|
||||
|
||||
# When called from non-driver worker, data is dict but when called from
|
||||
# driver worker, data is ExecuteModelRequest.
|
||||
if isinstance(data, dict):
|
||||
if "previous_hidden_states" in data:
|
||||
output["previous_hidden_states"] = data["previous_hidden_states"]
|
||||
elif data.previous_hidden_states is not None:
|
||||
output["previous_hidden_states"] = data.previous_hidden_states\
|
||||
.hidden_states
|
||||
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue